Commit 78519771 authored by H huangqipeng

[MLU] Adapt the MLU device for running the DBNet network

Parent 077196f3
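For context, a minimal sketch of what this adaptation enables at the framework level. This is an illustration, not part of the commit; it assumes a PaddlePaddle build compiled with MLU support and uses the same `is_compiled_with_mlu` / `set_device` APIs the hunks below rely on:

```python
import paddle

# Mirrors the new check_device() branch below: only select the MLU backend
# when this PaddlePaddle build was compiled with MLU support.
if paddle.device.is_compiled_with_mlu():
    paddle.set_device('mlu:0')   # card 0; a launcher picks a card per worker
else:
    paddle.set_device('cpu')

x = paddle.randn([2, 3])         # created on whichever device was selected
print(x.place)
```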
 Global:
   use_gpu: true
   use_xpu: false
+  use_mlu: false
   epoch_num: 1200
   log_smooth_window: 20
   print_batch_step: 10
......
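With the new `use_mlu` key in the Global section, the `preprocess` hunk further below reads it exactly like the existing switches. A minimal sketch of that lookup, assuming the YAML above has been loaded into a dict (the config path is hypothetical; in PaddleOCR the dict is then passed through `merge_config` for `-o` overrides):

```python
import yaml

# Hypothetical config path; any config carrying the Global section above works.
with open('configs/det/det_mv3_db.yml') as f:
    config = yaml.safe_load(f)

# Same default as the diff: a missing key means "do not use MLU".
use_mlu = config['Global'].get('use_mlu', False)
print(use_mlu)
```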
......
@@ -114,7 +114,7 @@ def merge_config(config, opts):
     return config
-def check_device(use_gpu, use_xpu=False, use_npu=False):
+def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
......
@@ -137,6 +137,9 @@ def check_device(use_gpu, use_xpu=False, use_npu=False):
         if use_npu and not paddle.device.is_compiled_with_npu():
             print(err.format("use_npu", "npu", "npu", "use_npu"))
             sys.exit(1)
+        if use_mlu and not paddle.device.is_compiled_with_mlu():
+            print(err.format("use_mlu", "mlu", "mlu", "use_mlu"))
+            sys.exit(1)
     except Exception as e:
         pass
......
@@ -618,6 +621,7 @@ def preprocess(is_train=False):
     use_gpu = config['Global'].get('use_gpu', False)
     use_xpu = config['Global'].get('use_xpu', False)
     use_npu = config['Global'].get('use_npu', False)
+    use_mlu = config['Global'].get('use_mlu', False)
     alg = config['Architecture']['algorithm']
     assert alg in [
......
@@ -632,10 +636,12 @@ def preprocess(is_train=False):
         device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
     elif use_npu:
         device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
+    elif use_mlu:
+        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
     else:
         device = 'gpu:{}'.format(dist.ParallelEnv()
                                  .dev_id) if use_gpu else 'cpu'
-    check_device(use_gpu, use_xpu, use_npu)
+    check_device(use_gpu, use_xpu, use_npu, use_mlu)
     device = paddle.set_device(device)
......
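The `FLAGS_selected_mlus` lookup above follows the existing XPU/NPU pattern. A hedged sketch of how it resolves at runtime (the environment variable name comes straight from the diff; that a distributed launcher exports it per worker, as with `FLAGS_selected_gpus`, is an assumption):

```python
import os
import paddle

# Requires an MLU-enabled build. Single-process runs leave
# FLAGS_selected_mlus unset and fall back to card 0; a launcher would
# export one card id per worker process.
card = os.getenv('FLAGS_selected_mlus', 0)
place = paddle.set_device('mlu:{0}'.format(card))
print(place)
```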
......
@@ -149,10 +149,11 @@ def main(config, device, logger, vdl_writer):
     amp_level = config["Global"].get("amp_level", 'O2')
     amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
     if use_amp:
-        AMP_RELATED_FLAGS_SETTING = {
-            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
-            'FLAGS_max_inplace_grad_add': 8,
-        }
+        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+        if paddle.is_compiled_with_cuda():
+            AMP_RELATED_FLAGS_SETTING.update({
+                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+            })
         paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
         scale_loss = config["Global"].get("scale_loss", 1.0)
         use_dynamic_loss_scaling = config["Global"].get(
......
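The reshuffle above exists because `FLAGS_cudnn_batchnorm_spatial_persistent` is cuDNN-specific: on an MLU (or any non-CUDA) build only `FLAGS_max_inplace_grad_add` is set now. A hedged sketch of the AMP training step these flags feed into (model, optimizer, and data are hypothetical stand-ins; `amp_level` 'O2' matches the default read above):

```python
import paddle

model = paddle.nn.Linear(10, 10)                       # hypothetical model
opt = paddle.optimizer.Adam(parameters=model.parameters())
# O2 keeps parameters in fp16; decorate() casts model and optimizer.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1.0)  # cf. scale_loss above

with paddle.amp.auto_cast(level='O2'):
    loss = model(paddle.randn([4, 10])).mean()
scaled = scaler.scale(loss)                            # apply loss scaling
scaled.backward()
scaler.minimize(opt, scaled)                           # unscale + step
```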