Commit a246ab94 authored by: T tink2123

support xpu for ocr

Parent e4600832
@@ -33,6 +33,7 @@ def init_args():
     parser = argparse.ArgumentParser()
     # params for prediction engine
     parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--use_xpu", type=str2bool, default=False)
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
@@ -277,6 +278,8 @@ def create_predictor(args, mode, logger):
             config.set_trt_dynamic_shape_info(
                 min_input_shape, max_input_shape, opt_input_shape)
+    elif args.use_xpu:
+        config.enable_xpu(10 * 1024 * 1024)
     else:
         config.disable_gpu()
         if hasattr(args, "cpu_threads"):
@@ -630,7 +633,6 @@ def get_rotate_crop_image(img, points):
 def check_gpu(use_gpu):
     if use_gpu and not paddle.is_compiled_with_cuda():
         use_gpu = False
     return use_gpu
......
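For context, the new --use_xpu switch routes inference to an XPU device (Baidu Kunlun) through the same inference Config object that already handles the GPU and CPU paths. Below is a minimal sketch, not code from this commit, of how that branch fits together; the model file names and the GPU memory pool size are placeholder assumptions.

# Sketch only: mirrors the GPU / XPU / CPU branching added to create_predictor().
# The model paths are placeholders; real inference model files are needed to
# actually create a predictor.
from paddle import inference


def build_config(use_gpu=False, use_xpu=False,
                 model_file="inference.pdmodel",
                 params_file="inference.pdiparams"):
    config = inference.Config(model_file, params_file)
    if use_gpu:
        config.enable_use_gpu(500, 0)        # 500 MB initial pool on GPU 0 (placeholder)
    elif use_xpu:
        config.enable_xpu(10 * 1024 * 1024)  # 10 MB L3 workspace, same value as the patch
    else:
        config.disable_gpu()                 # plain CPU predictor
    return config


# predictor = inference.create_predictor(build_config(use_xpu=True))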
@@ -128,20 +128,25 @@ def merge_config(config):
         cur = cur[sub_key]

-def check_gpu(use_gpu):
+def check_device(use_gpu, use_xpu=False):
     """
     Log error and exit when set use_gpu=true in paddlepaddle
     cpu version.
     """
-    err = "Config use_gpu cannot be set as true while you are " \
-          "using paddlepaddle cpu version ! \nPlease try: \n" \
-          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
-          "\t2. Set use_gpu as false in config file to run " \
+    err = "Config {} cannot be set as true while your paddle " \
+          "is not compiled with {} ! \nPlease try: \n" \
+          "\t1. Install paddlepaddle to run model on {} \n" \
+          "\t2. Set {} as false in config file to run " \
           "model on CPU"
     try:
+        if use_gpu and use_xpu:
+            print("use_xpu and use_gpu can not both be true.")
         if use_gpu and not paddle.is_compiled_with_cuda():
-            print(err)
+            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
+            sys.exit(1)
+        if use_xpu and not paddle.device.is_compiled_with_xpu():
+            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
             sys.exit(1)
     except Exception as e:
         pass
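The renamed check_device generalizes the old GPU-only error into a template that is filled in for whichever backend the current PaddlePaddle build lacks. Purely as an illustration (the template is copied from the hunk above), this is the message printed when use_xpu is true on a build without XPU support:

err = "Config {} cannot be set as true while your paddle " \
      "is not compiled with {} ! \nPlease try: \n" \
      "\t1. Install paddlepaddle to run model on {} \n" \
      "\t2. Set {} as false in config file to run " \
      "model on CPU"

print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))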
@@ -266,7 +271,7 @@ def train(config,
             stats['lr'] = lr
             train_stats.update(stats)

             if cal_metric_during_train and model_type is not "det":  # only rec and cls need
                 batch = [item.numpy() for item in batch]
                 if model_type in ['table', 'kie']:
                     eval_class(preds, batch)
@@ -497,7 +502,7 @@ def preprocess(is_train=False):
     # check if set use_gpu=True in paddlepaddle cpu version
     use_gpu = config['Global']['use_gpu']
-    check_gpu(use_gpu)
+    use_xpu = config['Global'].get('use_xpu', False)
     alg = config['Architecture']['algorithm']
     assert alg in [
@@ -511,7 +516,13 @@ def preprocess(is_train=False):
                 windows_not_support_list))
         sys.exit()
-    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
+    if use_xpu:
+        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
+    else:
+        device = 'gpu:{}'.format(dist.ParallelEnv()
+                                 .dev_id) if use_gpu else 'cpu'
+    check_device(use_gpu, use_xpu)
+
     device = paddle.set_device(device)

     config['Global']['distributed'] = dist.get_world_size() != 1
......
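For reference, the device selection that preprocess() now performs can be read in isolation as the sketch below. pick_device is a hypothetical helper introduced only for this example; the branch bodies mirror the diff, including the FLAGS_selected_xpus environment variable (defaulting to device 0 when unset).

# Standalone sketch of the device-selection logic added to preprocess().
# pick_device is a hypothetical name; it is not part of the patch.
import os

import paddle
import paddle.distributed as dist


def pick_device(use_gpu, use_xpu=False):
    if use_xpu:
        # FLAGS_selected_xpus chooses the XPU card, e.g. "0" or "1"
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    return device


device = paddle.set_device(pick_device(use_gpu=False))  # 'cpu' on a CPU-only build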