Commit 49ecf9c3 authored by zhangyikun02

add use_xpu config for det_mv3_db.yml

Parent d6ec303e
Global:
  use_gpu: true
  use_xpu: false
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 10
......
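For context, a minimal sketch of how the new flag is read back once the YAML is parsed; the file path and the use of plain pyyaml are assumptions, since PaddleOCR normally loads configs through its own reader:

import yaml

# Sketch, not part of the commit: inspect the flag added above.
with open('configs/det/det_mv3_db.yml') as f:   # assumed path
    config = yaml.safe_load(f)

print(config['Global']['use_gpu'])   # True
print(config['Global']['use_xpu'])   # False, the default added by this commit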
@@ -130,6 +130,25 @@ def check_gpu(use_gpu):
        pass


def check_xpu(use_xpu):
    """
    Log error and exit when set use_xpu=true in paddlepaddle
    cpu/gpu version.
    """
    err = "Config use_xpu cannot be set as true while you are " \
          "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-xpu to run model on XPU \n" \
          "\t2. Set use_xpu as false in config file to run " \
          "model on CPU/GPU"
    try:
        if use_xpu and not paddle.is_compiled_with_xpu():
            print(err)
            sys.exit(1)
    except Exception as e:
        pass
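The new guard mirrors the existing check_gpu: it is a no-op unless use_xpu is true on a paddlepaddle build compiled without XPU support, in which case it prints the error text and exits. A minimal, self-contained sketch of the condition it relies on (not part of the commit):

import paddle

# paddle.is_compiled_with_xpu() reports whether the installed paddlepaddle
# build has XPU support; check_xpu(True) exits only when this is False.
if not paddle.is_compiled_with_xpu():
    print('this build has no XPU support; use_xpu must stay false')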

def train(config,
          train_dataloader,
          valid_dataloader,
@@ -512,6 +531,12 @@ def preprocess(is_train=False):
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    # check if set use_xpu=True in paddlepaddle cpu/gpu version
    use_xpu = False
    if 'use_xpu' in config['Global']:
        use_xpu = config['Global']['use_xpu']
    check_xpu(use_xpu)
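The membership test above keeps older configs that never define use_xpu working: they fall through to the False default. A more compact equivalent with a dict default, shown only as a sketch (config and check_xpu are the names from the hunk above, not new definitions):

# Sketch, not part of the commit: equivalent lookup with a dict default.
use_xpu = config['Global'].get('use_xpu', False)
check_xpu(use_xpu)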
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
@@ -519,7 +544,11 @@ def preprocess(is_train=False):
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM'
    ]
    device = 'cpu'
    if use_gpu:
        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
    if use_xpu:
        device = 'xpu'
    device = paddle.set_device(device)
    config['Global']['distributed'] = dist.get_world_size() != 1
......
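One detail of the new device selection worth noting: the xpu branch is checked last, so a config that enables both use_gpu and use_xpu ends up on the XPU. A small sketch of that precedence outside the training script (flag values are illustrative; set_device is left commented so the snippet also runs on builds without XPU):

import paddle
import paddle.distributed as dist

use_gpu = True
use_xpu = True    # illustrative: both flags enabled at once

device = 'cpu'
if use_gpu:
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
if use_xpu:
    device = 'xpu'    # the later check wins

print(device)                 # -> 'xpu'
# paddle.set_device(device)   # would select the XPU on a paddlepaddle-xpu build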