Unverified · Commit 964643ee, authored by Aganlengzi, committed by GitHub

support custom npu (#7461)

Parent: cc15a9b6
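This change threads the `use_npu` flag through the Trainer's device checks (the SyncBatchNorm conversion and the `paddle.amp.auto_cast` call sites) and rewrites `check_npu` to detect the paddle-custom-npu plugin through the custom-device API instead of `paddle.is_compiled_with_npu()`.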
@@ -421,8 +421,10 @@ class Trainer(object):
         model = self.model
         if self.cfg.get('to_static', False):
             model = apply_to_static(self.cfg, model)
-        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
-                   (self.cfg.use_gpu or self.cfg.use_mlu) and self._nranks > 1)
+        sync_bn = (
+            getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
+            (self.cfg.use_gpu or self.cfg.use_npu or self.cfg.use_mlu) and
+            self._nranks > 1)
         if sync_bn:
             model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
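For context, a minimal standalone sketch of the pattern above, with plain booleans standing in for the `self.cfg.*` flags (the model and flag values here are assumptions for illustration, not the Trainer's actual configuration):

import paddle
import paddle.distributed as dist

# Stand-ins for cfg.use_gpu / cfg.use_npu / cfg.use_mlu and cfg.norm_type.
use_gpu, use_npu, use_mlu = True, False, False
norm_type = 'sync_bn'
nranks = dist.get_world_size()  # 1 when not launched in distributed mode

model = paddle.vision.models.resnet18()  # any nn.Layer containing BatchNorm
if norm_type == 'sync_bn' and (use_gpu or use_npu or use_mlu) and nranks > 1:
    # Replaces BatchNorm* sublayers with SyncBatchNorm so batch statistics
    # are synchronized across cards during distributed training.
    model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)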
@@ -484,7 +486,8 @@ class Trainer(object):
                             DataParallel) and use_fused_allreduce_gradients:
                         with model.no_sync():
                             with paddle.amp.auto_cast(
-                                    enable=self.cfg.use_gpu or self.cfg.use_mlu,
+                                    enable=self.cfg.use_gpu or
+                                    self.cfg.use_npu or self.cfg.use_mlu,
                                     custom_white_list=self.custom_white_list,
                                     custom_black_list=self.custom_black_list,
                                     level=self.amp_level):
@@ -498,7 +501,8 @@ class Trainer(object):
                             list(model.parameters()), None)
                     else:
                         with paddle.amp.auto_cast(
-                                enable=self.cfg.use_gpu or self.cfg.use_mlu,
+                                enable=self.cfg.use_gpu or self.cfg.use_npu or
+                                self.cfg.use_mlu,
                                 custom_white_list=self.custom_white_list,
                                 custom_black_list=self.custom_black_list,
                                 level=self.amp_level):
@@ -612,7 +616,8 @@ class Trainer(object):
             # forward
             if self.use_amp:
                 with paddle.amp.auto_cast(
-                        enable=self.cfg.use_gpu or self.cfg.use_mlu,
+                        enable=self.cfg.use_gpu or self.cfg.use_npu or
+                        self.cfg.use_mlu,
                         custom_white_list=self.custom_white_list,
                         custom_black_list=self.custom_black_list,
                         level=self.amp_level):
@@ -679,7 +684,8 @@ class Trainer(object):
             # forward
             if self.use_amp:
                 with paddle.amp.auto_cast(
-                        enable=self.cfg.use_gpu or self.cfg.use_mlu,
+                        enable=self.cfg.use_gpu or self.cfg.use_npu or
+                        self.cfg.use_mlu,
                         custom_white_list=self.custom_white_list,
                         custom_black_list=self.custom_black_list,
                         level=self.amp_level):
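Each of the `paddle.amp.auto_cast` call sites above gains the same `self.cfg.use_npu` term in its `enable` condition. A minimal standalone sketch of that pattern, with a toy model and hand-derived flags standing in for the Trainer's config (assumed here for illustration only):

import paddle

# Stand-ins for cfg.use_gpu / cfg.use_npu / cfg.use_mlu; on a machine without
# any of these devices the flags stay False and auto_cast becomes a no-op.
use_gpu = paddle.is_compiled_with_cuda()
use_npu = 'npu' in (paddle.device.get_all_custom_device_type() or [])
use_mlu = False

model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.SGD(parameters=model.parameters())
x = paddle.randn([8, 16])

with paddle.amp.auto_cast(
        enable=use_gpu or use_npu or use_mlu,
        custom_white_list=None,
        custom_black_list=None,
        level='O1'):
    loss = model(x).mean()  # forward runs in mixed precision when enabled

loss.backward()
opt.step()
opt.clear_grad()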
@@ -53,16 +53,17 @@ def check_mlu(use_mlu):
 def check_npu(use_npu):
     """
     Log error and exit when set use_npu=true in paddlepaddle
-    cpu/gpu/xpu version.
+    version without paddle-custom-npu installed.
     """
     err = "Config use_npu cannot be set as true while you are " \
-          "using paddlepaddle cpu/gpu/xpu version ! \nPlease try: \n" \
-          "\t1. Install paddlepaddle-npu to run model on NPU \n" \
+          "using paddlepaddle version without paddle-custom-npu " \
+          "installed! \nPlease try: \n" \
+          "\t1. Install paddle-custom-npu to run model on NPU \n" \
           "\t2. Set use_npu as false in config file to run " \
-          "model on CPU/GPU/XPU"
+          "model on other devices supported."

     try:
-        if use_npu and not paddle.is_compiled_with_npu():
+        if use_npu and not 'npu' in paddle.device.get_all_custom_device_type():
             logger.error(err)
             sys.exit(1)
     except Exception as e:
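The new check relies on the paddle-custom-npu plugin registering an 'npu' custom device type. A minimal sketch of that detection (the device selection shown is an assumption for illustration, not part of this commit):

import paddle

# With the paddle-custom-npu plugin installed, 'npu' appears among the
# registered custom device types; without it the list is empty.
custom_types = paddle.device.get_all_custom_device_type() or []
use_npu = 'npu' in custom_types

if use_npu:
    paddle.set_device('npu')  # e.g. 'npu:0' to pick a specific card
print('custom device types:', custom_types)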