Unverified commit d8a8ca81, authored by xiaoting, committed by GitHub

support xpu (#6382)

Parent: 453ae6bb
@@ -34,6 +34,7 @@ def init_args():
     parser = argparse.ArgumentParser()
     # params for prediction engine
     parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--use_xpu", type=str2bool, default=False)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
@@ -285,6 +286,8 @@ def create_predictor(args, mode, logger):
            config.set_trt_dynamic_shape_info(
                min_input_shape, max_input_shape, opt_input_shape)
+    elif args.use_xpu:
+        config.enable_xpu(10 * 1024 * 1024)
     else:
         config.disable_gpu()
         if hasattr(args, "cpu_threads"):
...
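Editor's note: the two hunks above add a --use_xpu flag and route it into the Paddle Inference config. Below is a minimal standalone sketch (not part of the commit) of how that XPU branch maps onto the paddle.inference API; the model and parameter file names are placeholders you would replace with a real exported model.

    # Hedged sketch mirroring the new XPU branch in create_predictor().
    # "inference.pdmodel" / "inference.pdiparams" are placeholder paths.
    from paddle.inference import Config, create_predictor

    config = Config("inference.pdmodel", "inference.pdiparams")
    use_xpu = True  # corresponds to the new --use_xpu flag
    if use_xpu:
        # Same call as in the diff: reserve a 10 MB L3 workspace on the XPU.
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()  # fall back to CPU, as in the existing else branch
    predictor = create_predictor(config)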
@@ -112,20 +112,25 @@ def merge_config(config, opts):
     return config
-def check_gpu(use_gpu):
+def check_device(use_gpu, use_xpu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
     """
-    err = "Config use_gpu cannot be set as true while you are " \
-          "using paddlepaddle cpu version ! \nPlease try: \n" \
-          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
-          "\t2. Set use_gpu as false in config file to run " \
+    err = "Config {} cannot be set as true while your paddle " \
+          "is not compiled with {} ! \nPlease try: \n" \
+          "\t1. Install paddlepaddle to run model on {} \n" \
+          "\t2. Set {} as false in config file to run " \
           "model on CPU"
     try:
+        if use_gpu and use_xpu:
+            print("use_xpu and use_gpu can not both be ture.")
         if use_gpu and not paddle.is_compiled_with_cuda():
-            print(err)
+            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
+            sys.exit(1)
+        if use_xpu and not paddle.device.is_compiled_with_xpu():
+            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
             sys.exit(1)
     except Exception as e:
         pass
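Editor's note: check_device now covers both backends by querying what the installed paddle wheel was compiled with. A small hedged usage sketch (not part of the commit), assuming a paddle 2.x install:

    import paddle

    # These are the same build queries check_device() relies on above.
    print("compiled with CUDA:", paddle.is_compiled_with_cuda())
    print("compiled with XPU :", paddle.device.is_compiled_with_xpu())

    # On a CPU-only wheel, check_device(use_gpu=True) prints the formatted
    # error message and exits with status 1.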
@@ -301,6 +306,7 @@ def train(config,
            stats['lr'] = lr
            train_stats.update(stats)
            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step)
@@ -547,7 +553,7 @@ def preprocess(is_train=False):
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
-    check_gpu(use_gpu)
+    use_xpu = config['Global'].get('use_xpu', False)
     # check if set use_xpu=True in paddlepaddle cpu/gpu version
     use_xpu = False
@@ -562,11 +568,13 @@ def preprocess(is_train=False):
         'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'PREN', 'FCE', 'SVTR'
     ]
+    device = 'cpu'
+    if use_gpu:
+        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
     if use_xpu:
-        device = 'xpu'
-    else:
-        device = 'gpu:{}'.format(dist.ParallelEnv()
-                                 .dev_id) if use_gpu else 'cpu'
+        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
+    check_device(use_gpu, use_xpu)
     device = paddle.set_device(device)
     config['Global']['distributed'] = dist.get_world_size() != 1
...
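Editor's note: the last hunk builds the device string ('cpu', 'gpu:<id>' or 'xpu:<id>') before handing it to paddle.set_device. A hedged standalone sketch of that selection logic (not part of the commit; paddle.set_device only succeeds on a build compiled with the requested backend):

    import os
    import paddle
    import paddle.distributed as dist

    # use_gpu / use_xpu would normally come from config['Global'].
    use_gpu, use_xpu = False, True

    device = 'cpu'
    if use_gpu:
        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id)
    if use_xpu:
        # The distributed launcher exports FLAGS_selected_xpus per process;
        # default to card 0 when it is unset.
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))

    place = paddle.set_device(device)  # e.g. Place(xpu:0) on an XPU build
    print(place)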