未验证 提交 39826e0f 编写于 作者: H houj04 提交者: GitHub

add xpu and npu support for animegan_v2_hayao_64. (#1648)

上级 eddc6adf
...@@ -8,9 +8,9 @@ __all__ = ['Model'] ...@@ -8,9 +8,9 @@ __all__ = ['Model']
class Model(): class Model():
# 初始化函数 # 初始化函数
def __init__(self, modelpath, use_gpu=False, use_mkldnn=True, combined=True): def __init__(self, modelpath, use_gpu=False, use_mkldnn=True, combined=True, use_device=None):
# 加载模型预测器 # 加载模型预测器
self.predictor = self.load_model(modelpath, use_gpu, use_mkldnn, combined) self.predictor = self.load_model(modelpath, use_gpu, use_mkldnn, combined, use_device)
# 获取模型的输入输出 # 获取模型的输入输出
self.input_names = self.predictor.get_input_names() self.input_names = self.predictor.get_input_names()
...@@ -18,18 +18,16 @@ class Model(): ...@@ -18,18 +18,16 @@ class Model():
self.input_handle = self.predictor.get_input_handle(self.input_names[0]) self.input_handle = self.predictor.get_input_handle(self.input_names[0])
self.output_handle = self.predictor.get_output_handle(self.output_names[0]) self.output_handle = self.predictor.get_output_handle(self.output_names[0])
# 模型加载函数 def _get_device_id(self, places):
def load_model(self, modelpath, use_gpu, use_mkldnn, combined): try:
# 对运行位置进行配置 places = os.environ[places]
if use_gpu: id = int(places)
try: except:
int(os.environ.get('CUDA_VISIBLE_DEVICES')) id = -1
except Exception: return id
print(
'Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU.'
)
use_gpu = False
# 模型加载函数
def load_model(self, modelpath, use_gpu, use_mkldnn, combined, use_device):
# 加载模型参数 # 加载模型参数
if combined: if combined:
model = os.path.join(modelpath, "__model__") model = os.path.join(modelpath, "__model__")
...@@ -38,13 +36,50 @@ class Model(): ...@@ -38,13 +36,50 @@ class Model():
else: else:
config = Config(modelpath) config = Config(modelpath)
# 设置参数 # 对运行位置进行配置
if use_gpu: if use_device is not None:
config.enable_use_gpu(100, 0) if use_device == "cpu":
if use_mkldnn:
config.enable_mkldnn()
elif use_device == "xpu":
xpu_id = self._get_device_id("XPU_VISIBLE_DEVICES")
if xpu_id != -1:
config.enable_xpu(100)
else:
print(
'Error! Unable to use XPU. Please set the environment variables "XPU_VISIBLE_DEVICES=XPU_id" to use XPU.'
)
elif use_device == "npu":
npu_id = self._get_device_id("FLAGS_selected_npus")
if npu_id != -1:
config.enable_npu(device_id=npu_id)
else:
print(
'Error! Unable to use NPU. Please set the environment variables "FLAGS_selected_npus=NPU_id" to use NPU.'
)
elif use_device == "gpu":
gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
if gpu_id != -1:
config.enable_use_gpu(100, gpu_id)
else:
print(
'Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU.'
)
else:
raise Exception("Unsupported device: " + use_device)
else: else:
config.disable_gpu() if use_gpu:
if use_mkldnn: gpu_id = self._get_device_id("CUDA_VISIBLE_DEVICES")
config.enable_mkldnn() if gpu_id != -1:
config.enable_use_gpu(100, gpu_id)
else:
print(
'Error! Unable to use GPU. Please set the environment variables "CUDA_VISIBLE_DEVICES=GPU_id" to use GPU.'
)
else:
if use_mkldnn:
config.enable_mkldnn()
config.disable_glog_info() config.disable_glog_info()
config.switch_ir_optim(True) config.switch_ir_optim(True)
config.enable_memory_optim() config.enable_memory_optim()
......
...@@ -17,14 +17,15 @@ from animegan_v2_hayao_64.processor import base64_to_cv2, cv2_to_base64, Process ...@@ -17,14 +17,15 @@ from animegan_v2_hayao_64.processor import base64_to_cv2, cv2_to_base64, Process
) )
class Animegan_V2_Hayao_64(Module): class Animegan_V2_Hayao_64(Module):
# 初始化函数 # 初始化函数
def __init__(self, name=None, use_gpu=False): def __init__(self, name=None, use_gpu=False, use_device=None):
# 设置模型路径 # 设置模型路径
self.model_path = os.path.join(self.directory, "animegan_v2_hayao_64") self.model_path = os.path.join(self.directory, "animegan_v2_hayao_64")
# 加载模型 # 加载模型
self.model = Model(modelpath=self.model_path, use_gpu=use_gpu, use_mkldnn=False, combined=False) self.model = Model(
modelpath=self.model_path, use_gpu=use_gpu, use_mkldnn=False, combined=False, use_device=use_device)
# 关键点检测函数 # 风格转换函数
def style_transfer(self, def style_transfer(self,
images=None, images=None,
paths=None, paths=None,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册