Unverified commit 500f070d, authored by qipengh, committed by GitHub

[MLU] fix sync_bn of mlu and add unittests (#45707)

* [MLU] fix sync_bn of mlu and add unittests

* [MLU] remove redundant code of pytest
Parent b7d219be
@@ -159,9 +159,9 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
                            GetBasePtr(&local_var));

     Tensor input_count;
-    input_count.mutable_data<T>(phi::make_ddim({1}), ctx.GetPlace());
-    FillMLUTensorWithHostValue<T>(
-        ctx, static_cast<T>(x->numel() / C), &input_count);
+    input_count.mutable_data<MPDType>(phi::make_ddim({1}), ctx.GetPlace());
+    FillMLUTensorWithHostValue<MPDType>(
+        ctx, static_cast<MPDType>(x->numel() / C), &input_count);

     Tensor count_all;
     Tensor mean_all(mean->dtype());
@@ -170,15 +170,23 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
 #ifdef PADDLE_WITH_CNCL
       auto &dev_ctx =
           ctx.template device_context<paddle::platform::MLUDeviceContext>();
-      auto stream = dev_ctx.stream();
       auto *comm = dev_ctx.cncl_comm();
       if (comm) {
-        auto *comm = paddle::platform::CNCLCommContext::Instance()
-                         .Get(0, ctx.GetPlace())
-                         ->comm();
+        auto cncl_comm = paddle::platform::CNCLCommContext::Instance().Get(
+            0, ctx.GetPlace());
+        auto *comm = cncl_comm->comm();
+        auto comm_stream = cncl_comm->stream();
         int count;
         PADDLE_ENFORCE_MLU_SUCCESS(cnclGetCommCount(&count, comm));
-        count_all.mutable_data<T>(phi::make_ddim({count}), ctx.GetPlace());
+        count_all.mutable_data<MPDType>(phi::make_ddim({count}),
+                                        ctx.GetPlace());
+        mean_all.mutable_data<MPDType>(phi::make_ddim({count, mean->numel()}),
+                                       ctx.GetPlace());
+        invstd_all.mutable_data<MPDType>(
+            phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
+        // before comm_stream exec, need sync compute_stream.
+        dev_ctx.Wait();
         cnclDataType_t dtype = platform::ToCNCLDataType(
             framework::TransToProtoVarType(count_all.dtype()));
         PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&input_count),
@@ -186,12 +194,7 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
                                                  1,
                                                  dtype,
                                                  comm,
-                                                 stream));
-        mean_all.mutable_data<MPDType>(phi::make_ddim({count, mean->numel()}),
-                                       ctx.GetPlace());
-        invstd_all.mutable_data<MPDType>(
-            phi::make_ddim({count, variance->numel()}), ctx.GetPlace());
+                                                 comm_stream));
         auto cncl_dtype = platform::ToCNCLDataType(
             framework::TransToProtoVarType(mean_all.dtype()));
@@ -200,14 +203,17 @@ class SyncBatchNormMLUKernel : public framework::OpKernel<T> {
                                                  local_mean.numel(),
                                                  cncl_dtype,
                                                  comm,
-                                                 stream));
+                                                 comm_stream));
         PADDLE_ENFORCE_MLU_SUCCESS(cnclAllGather(GetBasePtr(&local_var),
                                                  GetBasePtr(&invstd_all),
                                                  local_var.numel(),
                                                  cncl_dtype,
                                                  comm,
-                                                 stream));
+                                                 comm_stream));
+        // after comm_stream exec, need sync queue for using compute_stream
+        // correctly.
+        PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
 #else
     if (NO_USE_CNCL) {
 #endif
@@ -412,12 +418,14 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
 #ifdef PADDLE_WITH_CNCL
     auto &dev_ctx =
         ctx.template device_context<paddle::platform::MLUDeviceContext>();
-    auto stream = dev_ctx.stream();
     auto *comm = dev_ctx.cncl_comm();
     if (comm) {
-      auto *comm = paddle::platform::CNCLCommContext::Instance()
-                       .Get(0, ctx.GetPlace())
-                       ->comm();
+      auto cncl_comm =
+          paddle::platform::CNCLCommContext::Instance().Get(0, ctx.GetPlace());
+      auto *comm = cncl_comm->comm();
+      auto comm_stream = cncl_comm->stream();
+      // before comm_stream exec, need sync compute_stream.
+      dev_ctx.Wait();
       cnclDataType_t dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(numel_count.dtype()));
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&numel_count),
@@ -426,7 +434,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
                                                dtype,
                                                cnclSum,
                                                comm,
-                                               stream));
+                                               comm_stream));
       auto cncl_dtype = platform::ToCNCLDataType(
           framework::TransToProtoVarType(sum_dy.dtype()));
@@ -436,7 +444,7 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
                                               cncl_dtype,
                                               cnclSum,
                                               comm,
-                                              stream));
+                                              comm_stream));
       PADDLE_ENFORCE_MLU_SUCCESS(cnclAllReduce(GetBasePtr(&sum_dy_xmu),
                                                GetBasePtr(&sum_dy_xmu),
@@ -444,7 +452,10 @@ class SyncBatchNormMLUGradKernel : public framework::OpKernel<T> {
                                               cncl_dtype,
                                               cnclSum,
                                               comm,
-                                              stream));
+                                              comm_stream));
+      // after comm_stream exec, need sync queue for using compute_stream
+      // correctly.
+      PADDLE_ENFORCE_MLU_SUCCESS(cnrtQueueSync(comm_stream));
     }
 #endif
...
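Note on the kernel change above: the collectives now run on the communicator's own queue (`comm_stream`) rather than the compute stream, so the kernel calls `dev_ctx.Wait()` before the gathers and `cnrtQueueSync(comm_stream)` after them to keep the two queues ordered. What the gathered `count_all`/`mean_all`/`invstd_all` buffers feed into is the usual cross-device statistics merge. A minimal NumPy sketch of that merge, for illustration only (the kernel actually gathers inverse standard deviations; plain variances are used here for clarity):

```python
import numpy as np

def combine_sync_bn_stats(counts, means, variances):
    """Merge per-device BN stats; counts: (nranks,), means/variances: (nranks, C)."""
    total = counts.sum()
    # Global mean: count-weighted average of the local means.
    global_mean = (counts[:, None] * means).sum(axis=0) / total
    # Global variance: local variances plus the spread of local means
    # around the global mean (law of total variance).
    global_var = (counts[:, None] *
                  (variances + (means - global_mean)**2)).sum(axis=0) / total
    return global_mean, global_var
```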
@@ -7220,9 +7220,9 @@ def device_guard(device=None):
         device, index = device.split(':')
         if device == 'cpu':
             raise ValueError("Should not set device id for cpu.")
-    if device not in ['cpu', 'gpu', 'npu', 'xpu', '', None]:
+    if device not in ['cpu', 'gpu', 'npu', 'xpu', 'mlu', '', None]:
         raise ValueError(
-            "The Attr(device) should be 'cpu' 'npu' 'xpu' or 'gpu', and it can also be empty string or None "
+            "The Attr(device) should be 'cpu' 'npu' 'xpu' 'mlu' or 'gpu', and it can also be empty string or None "
             "when there is no need to specify device. But received %s" % device)
     if index:
         device = ":".join([device, index])
...
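With 'mlu' now accepted by device_guard, static-graph ops can be pinned to an MLU card the same way as 'gpu' or 'npu'. A small usage sketch (the mlu branch only takes effect on an MLU build of Paddle; on other builds it raises at op placement time):

```python
import paddle

paddle.enable_static()
with paddle.static.device_guard("cpu"):        # always valid; no index for cpu
    x = paddle.full(shape=[2, 2], fill_value=1.0)
with paddle.static.device_guard("mlu:0"):      # assumes an MLU build and device
    y = x + 1.0
```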
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import contextlib
import unittest
import numpy as np
import six
import pickle
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.nn import Conv2D, Linear, SyncBatchNorm
from paddle.fluid.dygraph.base import to_variable
import sys
sys.path.append("..")
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase


class TestLayer(fluid.dygraph.Layer):

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(TestLayer, self).__init__()

        self._conv = Conv2D(in_channels=num_channels,
                            out_channels=num_filters,
                            kernel_size=filter_size,
                            stride=stride,
                            padding=(filter_size - 1) // 2,
                            groups=groups,
                            bias_attr=False)

        self._sync_batch_norm = SyncBatchNorm(num_filters)

        self._conv2 = Conv2D(in_channels=num_filters,
                             out_channels=num_filters,
                             kernel_size=filter_size,
                             stride=stride,
                             padding=(filter_size - 1) // 2,
                             groups=groups,
                             bias_attr=False)

        self._sync_batch_norm2 = SyncBatchNorm(num_filters,
                                               weight_attr=False,
                                               bias_attr=False)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._sync_batch_norm(y)
        y = self._conv2(y)
        y = self._sync_batch_norm2(y)
        return y


class TestSyncBatchNorm(TestParallelDyGraphRunnerBase):

    def get_model(self):
        model = TestLayer(3, 64, 7)
        train_reader = paddle.batch(paddle.dataset.flowers.test(use_xmap=False),
                                    batch_size=32,
                                    drop_last=True)
        opt = fluid.optimizer.Adam(learning_rate=1e-3,
                                   parameter_list=model.parameters())
        return model, train_reader, opt

    def run_one_loop(self, model, opt, data):
        batch_size = len(data)
        dy_x_data = np.array([x[0].reshape(3, 224, 224)
                              for x in data]).astype('float32')
        img = to_variable(dy_x_data)
        img.stop_gradient = False

        out = model(img)
        out = paddle.mean(out)

        return out


if __name__ == "__main__":
    runtime_main(TestSyncBatchNorm)
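The runner above builds TestLayer with SyncBatchNorm layers directly. For reference, a model written with plain BatchNorm can be converted wholesale; a hedged sketch of the usual dygraph data-parallel setup (these are standard Paddle 2.x APIs, not part of this PR, and `init_parallel_env` is assumed to select the CNCL backend on MLU builds):

```python
import paddle
import paddle.distributed as dist

def build_parallel_model():
    dist.init_parallel_env()  # assumed to pick CNCL on an MLU build
    model = TestLayer(3, 64, 7)  # defined in the file above
    # Converts any plain BatchNorm sublayers to SyncBatchNorm in place;
    # TestLayer already uses SyncBatchNorm, so this is a no-op here.
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return paddle.DataParallel(model)
```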
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
import os
import subprocess
import pickle
DEFAULT_BATCH_SIZE = 2

flag_name = os.path.splitext(__file__)[0]
print("file: {}".format(flag_name))


class TestParallelDygraphMnistMLU(TestDistBase):

    def _setup_config(self):
        self._sync_mode = False
        self._cncl_mode = True
        self._dygraph = True
        self._enforce_place = "MLU"

    def _get_required_envs(self, check_error_log=False, need_envs={}):
        required_envs = {
            "PATH": os.getenv("PATH", ""),
            "PYTHONPATH": os.getenv("PYTHONPATH", ""),
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "LD_PRELOAD": os.getenv("LD_PRELOAD", ""),
            "FLAGS_fraction_of_gpu_memory_to_use": "0.15",
            "FLAGS_eager_delete_tensor_gb": "0.0",
            "FLAGS_call_stack_level": "2",
            "GLOG_v": "2",
            "PADDLE_WITH_GLOO": '0',
            "BACKEND": "cncl"
        }

        if check_error_log:
            required_envs["GLOG_v"] = "5"
            required_envs["GLOG_logtostderr"] = "1"
            required_envs["GLOO_LOG_LEVEL"] = "TRACE"

        required_envs.update(need_envs)
        return required_envs

    def _run_local(self,
                   model,
                   envs,
                   check_error_log=False,
                   batch_size=DEFAULT_BATCH_SIZE,
                   batch_merge_repeat=1,
                   log_name="",
                   devices="1"):

        cmd = self._python_interp

        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
            cmd += " -m coverage run --branch -p"

        cmd += " %s --role trainer --update_method local --lr %f" % (model,
                                                                     self._lr)

        if batch_size != DEFAULT_BATCH_SIZE:
            cmd += " --batch_size %d" % batch_size
        if batch_merge_repeat > 1:
            cmd += " --batch_merge_repeat %d" % batch_merge_repeat
        if self._nccl2_reduce_layer:
            cmd += " --nccl2_reduce_layer_local_run 1"

        if self._use_mlu:
            cmd += " --use_mlu"
            env_local = {
                "FLAGS_selected_mlus": devices,
                "PADDLE_TRAINERS_NUM": "1",
                "PADDLE_TRAINER_ID": "0"
            }
        else:
            env_local = {'CPU_NUM': '1'}

        # not use dgc in single card
        if len(devices) > 1 and self._use_dgc:
            cmd += " --use_dgc"

        if self._accumulate_gradient:
            cmd += " --accumulate_gradient"

        if self._find_unused_parameters:
            cmd += " --find_unused_parameters"

        env_local.update(envs)
        print("local_cmd: {}, env: {}".format(cmd, env_local))

        if check_error_log:
            path = "/tmp/local_err_%d.log" % os.getpid()
            err_log = open(path, "w")
            local_proc = subprocess.Popen(cmd.split(" "),
                                          stdout=subprocess.PIPE,
                                          stderr=err_log,
                                          env=env_local)
        else:
            local_proc = subprocess.Popen(cmd.split(" "),
                                          stdout=subprocess.PIPE,
                                          stderr=subprocess.PIPE,
                                          env=env_local)

        local_out, local_err = local_proc.communicate()

        if check_error_log:
            err_log.close()
            sys.stderr.write(
                '\n--run_local-- trainer 0 stderr file saved in: %s\n' % (path))

        sys.stderr.write('local_stderr: %s\n' % local_err)
        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))

        return pickle.loads(local_out)

    def _run_cluster_nccl2(self, model, envs, update_method, check_error_log,
                           log_name):
        # NOTE: we reuse ps_endpoints as nccl2 worker endpoints
        worker_endpoints = self._ps_endpoints.split(",")
        trainer_num = len(worker_endpoints)

        procs = []
        pipes = []
        for i in range(0, trainer_num):
            tr_cmd, tr_env = self._get_nccl2_trainer_cmd(
                model, worker_endpoints[i], update_method, i, trainer_num)
            tr_env.update(envs)
            print("use_hallreduce:{} \ntr{}_cmd:{}, env: {}".format(
                self._use_hallreduce, i, tr_cmd, tr_env))

            tr_pipe = open("/tmp/tr%d_err_%d.log" % (i, os.getpid()), "w")

            sys.stderr.write(
                "\n{} going to start process {} with nccl2\n".format(
                    type(self).__name__, i))
            tr_proc = subprocess.Popen(tr_cmd.strip().split(" "),
                                       stdout=subprocess.PIPE,
                                       stderr=tr_pipe,
                                       env=tr_env)

            procs.append(tr_proc)
            pipes.append(tr_pipe)

        outs = []
        for i in range(0, trainer_num):
            tr_out, tr_err = procs[i].communicate()
            outs.append(tr_out)
            pipes[i].close()
            sys.stderr.write('trainer {} stderr: {}\n'.format(i, tr_err))

            sys.stderr.write(
                'trainer {} glog file saved in: /tmp/tr{}_err_{}.log \n'.format(
                    i, i, os.getpid()))

        if check_error_log:
            print("outs[0]:", pickle.loads(outs[0]))
            print("outs[1]:", pickle.loads(outs[1]))

        return pickle.loads(outs[0]), pickle.loads(outs[1])

    def test_mnist(self):
        if fluid.core.is_compiled_with_mlu():
            self.check_with_place(
                os.path.abspath("parallel_dygraph_sync_batch_norm.py"),
                delta=1e-5,
                check_error_log=True,
                log_name=flag_name)


if __name__ == "__main__":
    unittest.main()
...@@ -126,19 +126,19 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -126,19 +126,19 @@ class TestSyncBatchNormRunnerBase(object):
for layout in ["NCHW", "NHWC"]: for layout in ["NCHW", "NHWC"]:
self._compare(args, place, layout, True) self._compare(args, place, layout, True)
# # Test FP16 - @TODO # Test FP16 - @TODO
# self.dtype = np.float16 self.dtype = np.float16
# self.atol = 1e-2 self.atol = 1e-2
# # Test training # Test training
# for place in places: for place in places:
# for layout in ["NCHW", "NHWC"]: for layout in ["NCHW", "NHWC"]:
# self._compare(args, place, layout, False) self._compare(args, place, layout, False)
# # Test inference # Test inference
# for place in places: for place in places:
# for layout in ["NCHW", "NHWC"]: for layout in ["NCHW", "NHWC"]:
# self._compare(args, place, layout, True) self._compare(args, place, layout, True)
sys.stdout.buffer.write( sys.stdout.buffer.write(
pickle.dumps( pickle.dumps(
...@@ -333,8 +333,8 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -333,8 +333,8 @@ class TestSyncBatchNormRunnerBase(object):
self.initCommunicator(startup_prog, rank, nranks, True, self.initCommunicator(startup_prog, rank, nranks, True,
current_endpoint, endpoints) current_endpoint, endpoints)
sys.stderr.write("after init, startup_prog: " + # sys.stderr.write("after init, startup_prog: " +
startup_prog.to_string(True) + "\n") # startup_prog.to_string(True) + "\n")
train_prog.global_seed(SEED) train_prog.global_seed(SEED)
train_prog._sync_with_cpp() train_prog._sync_with_cpp()
startup_prog.global_seed(SEED) startup_prog.global_seed(SEED)
...@@ -344,10 +344,10 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -344,10 +344,10 @@ class TestSyncBatchNormRunnerBase(object):
self.rank = rank self.rank = rank
outs = self.get_model(train_prog, startup_prog, place, layout, SEED, outs = self.get_model(train_prog, startup_prog, place, layout, SEED,
True, only_forward) True, only_forward)
sys.stderr.write("after get_model, train_prog: " + # sys.stderr.write("after get_model, train_prog: " +
train_prog.to_string(True) + "\n") # train_prog.to_string(True) + "\n")
sys.stderr.write("after get_model, startup_prog: " + # sys.stderr.write("after get_model, startup_prog: " +
startup_prog.to_string(True) + "\n") # startup_prog.to_string(True) + "\n")
ops = train_prog.blocks[0].ops ops = train_prog.blocks[0].ops
for i, op in enumerate(ops): for i, op in enumerate(ops):
...@@ -360,8 +360,8 @@ class TestSyncBatchNormRunnerBase(object): ...@@ -360,8 +360,8 @@ class TestSyncBatchNormRunnerBase(object):
sys.stderr.write("op type: " + op.type + "\n") sys.stderr.write("op type: " + op.type + "\n")
op.desc.set_type('sync_batch_norm_grad') op.desc.set_type('sync_batch_norm_grad')
sys.stderr.write("after update sync_batch_norm, train_prog: " + # sys.stderr.write("after update sync_batch_norm, train_prog: " +
train_prog.to_string(True) + "\n") # train_prog.to_string(True) + "\n")
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
......
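The FP16 comparisons enabled above use atol = 1e-2 rather than the FP32 tolerance because half precision carries only about three decimal digits. A quick illustration of the spacing involved:

```python
import numpy as np

print(np.finfo(np.float16).eps)         # ~0.000977: step size near 1.0
x = np.float16(1.0) + np.float16(3e-4)  # addend is below the step size
print(x == np.float16(1.0))             # True: the addend rounds away
```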
@@ -17,3 +17,5 @@
 set -e
 MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_sync_batch_norm_op_mlu_baseline.py
+
+MLU_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch test_parallel_dygraph_sync_batch_norm_mlu.py
@@ -552,6 +552,9 @@ class TestParallelDyGraphRunnerBase(object):
         elif fluid.core.is_compiled_with_npu():
             device_id = int(os.getenv("FLAGS_selected_npus", "0"))
             place = fluid.NPUPlace(device_id)
+        elif fluid.core.is_compiled_with_mlu():
+            device_id = int(os.getenv("FLAGS_selected_mlus", "0"))
+            place = fluid.MLUPlace(device_id)
         else:
             assert ("Only support CUDAPlace or XPUPlace or CPU(Gloo) for now.")
@@ -565,7 +568,7 @@ class TestParallelDyGraphRunnerBase(object):
         nranks = len(args.endpoints.split(",")) if args.endpoints else 1
         #if args.update_method == "nccl2":
-        if args.update_method == "nccl2" or args.update_method == "bkcl" or args.update_method == "hccl":
+        if args.update_method == "nccl2" or args.update_method == "bkcl" or args.update_method == "hccl" or args.update_method == "cncl":
             strategy = dygraph.parallel.ParallelStrategy()
             strategy.nranks = nranks
             strategy.local_rank = args.trainer_id
@@ -708,7 +711,7 @@ def runtime_main(test_class):
         default="local",
         choices=[
             "pserver", "nccl2", "bkcl", "local",
-            "nccl2_reduce_layer", "gloo", "hccl"
+            "nccl2_reduce_layer", "gloo", "hccl", "cncl"
         ])
     parser.add_argument('--trainer_id', type=int, required=False, default=0)
     parser.add_argument('--trainers', type=int, required=False, default=1)
@@ -735,6 +738,7 @@ def runtime_main(test_class):
     parser.add_argument('--use_xpu', action='store_true')
     parser.add_argument('--use_dgc', action='store_true')
     parser.add_argument('--use_npu', action='store_true')
+    parser.add_argument('--use_mlu', action='store_true')
     parser.add_argument('--accumulate_gradient', action='store_true')
     parser.add_argument('--find_unused_parameters', action='store_true')
     parser.add_argument('--use_reduce', action='store_true')
@@ -794,20 +798,30 @@ class TestDistBase(unittest.TestCase):
             self.__use_xpu = False
             self._use_dgc = False
             self.__use_npu = False
+            self._use_mlu = False
         elif self._enforce_place == "GPU":
             self.__use_cuda = True
             self.__use_xpu = False
             self.__use_npu = False
+            self._use_mlu = False
         elif self._enforce_place == "XPU":
             self.__use_cuda = False
             self.__use_xpu = True
             self._use_dgc = False
             self.__use_npu = False
+            self._use_mlu = False
         elif self._enforce_place == "NPU":
             self.__use_cuda = False
             self.__use_xpu = False
             self._use_dgc = False
             self.__use_npu = True
+            self._use_mlu = False
+        elif self._enforce_place == "MLU":
+            self.__use_cuda = False
+            self.__use_xpu = False
+            self._use_dgc = False
+            self.__use_npu = False
+            self._use_mlu = True
         else:
             if fluid.core.is_compiled_with_cuda():
                 self.__use_cuda = True
@@ -833,6 +847,7 @@ class TestDistBase(unittest.TestCase):
         self._bkcl_mode = False
         self._gloo_mode = False  # now, support gloo backend
         self._hccl_mode = False
+        self._cncl_mode = False
         self._pipeline_mode = False
         self._mp_mode = False
         self._diff_batch = False
@@ -1243,6 +1258,16 @@ class TestDistBase(unittest.TestCase):
                 "PADDLE_CURRENT_ENDPOINT": ep,
                 "GLOG_v": "2",
             })
+        elif self._use_mlu:
+            tr_cmd += " --use_mlu"
+            env.update({
+                "FLAGS_selected_mlus": "{}".format(trainer_id),
+                "PADDLE_TRAINERS_NUM": "{}".format(trainer_num),
+                "PADDLE_TRAINER_ID": "{}".format(trainer_id),
+                "PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
+                "PADDLE_CURRENT_ENDPOINT": ep,
+                "GLOG_v": "4",
+            })
         else:
             env.update({'CPU_NUM': '1'})
@@ -1556,7 +1581,13 @@ class TestDistBase(unittest.TestCase):
                 update_method='hccl',
                 check_error_log=check_error_log,
                 log_name=log_name)
+        elif self._cncl_mode:
+            tr0_losses, tr1_losses = self._run_cluster_nccl2(
+                model_file,
+                required_envs,
+                update_method='cncl',
+                check_error_log=check_error_log,
+                log_name=log_name)
         elif self._pipeline_mode:
             tr0_losses, tr1_losses = self._run_pipeline(model_file,
                                                         required_envs,
...