提交 78519771 编写于 作者: H huangqipeng

[MLU]adapt mlu device for running dbnet network

上级 077196f3
Global: Global:
use_gpu: true use_gpu: true
use_xpu: false use_xpu: false
use_mlu: false
epoch_num: 1200 epoch_num: 1200
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
......
...@@ -114,7 +114,7 @@ def merge_config(config, opts): ...@@ -114,7 +114,7 @@ def merge_config(config, opts):
return config return config
def check_device(use_gpu, use_xpu=False, use_npu=False): def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
""" """
Log error and exit when set use_gpu=true in paddlepaddle Log error and exit when set use_gpu=true in paddlepaddle
cpu version. cpu version.
...@@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False): ...@@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False):
if use_npu and not paddle.device.is_compiled_with_npu(): if use_npu and not paddle.device.is_compiled_with_npu():
print(err.format("use_npu", "npu", "npu", "use_npu")) print(err.format("use_npu", "npu", "npu", "use_npu"))
sys.exit(1) sys.exit(1)
if use_mlu and not paddle.device.is_compiled_with_mlu():
print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
sys.exit(1)
except Exception as e: except Exception as e:
pass pass
...@@ -618,6 +621,7 @@ def preprocess(is_train=False): ...@@ -618,6 +621,7 @@ def preprocess(is_train=False):
use_gpu = config['Global'].get('use_gpu', False) use_gpu = config['Global'].get('use_gpu', False)
use_xpu = config['Global'].get('use_xpu', False) use_xpu = config['Global'].get('use_xpu', False)
use_npu = config['Global'].get('use_npu', False) use_npu = config['Global'].get('use_npu', False)
use_mlu = config['Global'].get('use_mlu', False)
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
...@@ -632,10 +636,12 @@ def preprocess(is_train=False): ...@@ -632,10 +636,12 @@ def preprocess(is_train=False):
device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0)) device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
elif use_npu: elif use_npu:
device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0)) device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
elif use_mlu:
device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
else: else:
device = 'gpu:{}'.format(dist.ParallelEnv() device = 'gpu:{}'.format(dist.ParallelEnv()
.dev_id) if use_gpu else 'cpu' .dev_id) if use_gpu else 'cpu'
check_device(use_gpu, use_xpu, use_npu) check_device(use_gpu, use_xpu, use_npu, use_mlu)
device = paddle.set_device(device) device = paddle.set_device(device)
......
...@@ -149,10 +149,11 @@ def main(config, device, logger, vdl_writer): ...@@ -149,10 +149,11 @@ def main(config, device, logger, vdl_writer):
amp_level = config["Global"].get("amp_level", 'O2') amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list', []) amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, if paddle.is_compiled_with_cuda():
'FLAGS_max_inplace_grad_add': 8, AMP_RELATED_FLAGS_SETTING.update({
} 'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0) scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get( use_dynamic_loss_scaling = config["Global"].get(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册