提交 49ecf9c3 编写于 作者: Z zhangyikun02

add use_xpu config for det_mv3_db.yml

上级 d6ec303e
Global:
use_gpu: true
use_xpu: false
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 10
......
......@@ -130,6 +130,25 @@ def check_gpu(use_gpu):
pass
def check_xpu(use_xpu):
"""
Log error and exit when set use_xpu=true in paddlepaddle
cpu/gpu version.
"""
err = "Config use_xpu cannot be set as true while you are " \
"using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
"\t1. Install paddlepaddle-xpu to run model on XPU \n" \
"\t2. Set use_xpu as false in config file to run " \
"model on CPU/GPU"
try:
if use_xpu and not paddle.is_compiled_with_xpu():
print(err)
sys.exit(1)
except Exception as e:
pass
def train(config,
train_dataloader,
valid_dataloader,
......@@ -512,6 +531,12 @@ def preprocess(is_train=False):
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
# check if set use_xpu=True in paddlepaddle cpu/gpu version
use_xpu = False
if 'use_xpu' in config['Global']:
use_xpu = config['Global']['use_xpu']
check_xpu(use_xpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
......@@ -519,7 +544,11 @@ def preprocess(is_train=False):
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = 'cpu'
if use_gpu:
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
if use_xpu:
device = 'xpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册