Unverified commit a99e5064, authored by duanyanhui, committed by GitHub

update npu inference api (#779)

Parent: fb1b085b
@@ -118,7 +118,7 @@ def create_predictor(model_path,
     elif device == "cpu":
         config.disable_gpu()
     elif device == "npu":
-        config.enable_npu()
+        config.enable_custom_device('npu')
     elif device == "xpu":
         config.enable_xpu()
     else:
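The hunk above is the functional change in this commit: Ascend NPU selection now goes through Paddle Inference's CustomDevice interface instead of the old enable_npu() call. A minimal sketch of the surrounding device-selection logic, assuming a Paddle build with CustomDevice (NPU) support; the model file paths and the enable_use_gpu() values are illustrative, not taken from the diff:

from paddle.inference import Config

def build_config(model_file, params_file, device="npu"):
    # Hypothetical helper mirroring create_predictor()'s device branch.
    config = Config(model_file, params_file)
    if device == "gpu":
        # 100 MB initial GPU memory pool on card 0 (illustrative values).
        config.enable_use_gpu(100, 0)
    elif device == "cpu":
        config.disable_gpu()
    elif device == "npu":
        # NPU is reached through the CustomDevice plugin, replacing the
        # enable_npu() call removed in this commit.
        config.enable_custom_device('npu')
    elif device == "xpu":
        config.enable_xpu()
    return config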
@@ -174,11 +174,10 @@ def main():
     random.seed(args.seed)
     np.random.seed(args.seed)
     cfg = get_config(args.config_file, args.opt, show=True)
-    predictor, config = create_predictor(args.model_path, args.device, args.run_mode,
-                                          args.batch_size, args.min_subgraph_size,
-                                          args.use_dynamic_shape, args.trt_min_shape,
-                                          args.trt_max_shape, args.trt_opt_shape,
-                                          args.trt_calib_mode)
+    predictor, config = create_predictor(
+        args.model_path, args.device, args.run_mode, args.batch_size,
+        args.min_subgraph_size, args.use_dynamic_shape, args.trt_min_shape,
+        args.trt_max_shape, args.trt_opt_shape, args.trt_calib_mode)
     input_handles = [
         predictor.get_input_handle(name)
         for name in predictor.get_input_names()
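This hunk only reflows the create_predictor() call. For context, a rough sketch of how a predictor built from such a Config is driven through the input/output handle API that appears above; the contents of arrays (one numpy array per model input) are placeholders:

from paddle.inference import create_predictor

def run_once(config, arrays):
    # Build the predictor from the Config assembled earlier and feed one batch.
    predictor = create_predictor(config)
    input_handles = [
        predictor.get_input_handle(name)
        for name in predictor.get_input_names()
    ]
    for handle, array in zip(input_handles, arrays):
        handle.reshape(array.shape)
        handle.copy_from_cpu(array)
    predictor.run()
    # Copy every output tensor back to host memory.
    return [
        predictor.get_output_handle(name).copy_to_cpu()
        for name in predictor.get_output_names()
    ]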
@@ -225,7 +224,7 @@ def main():
     elif model_type == "cyclegan":
         import auto_log
         logger = get_logger(name='ppgan')
         size = data['A'].shape
         pid = os.getpid()
         auto_logger = auto_log.AutoLogger(
@@ -233,7 +232,7 @@ def main():
             model_precision=args.run_mode,
             batch_size=args.batch_size,
             data_shape=size,
-            save_path=args.output_path+'auto_log.lpg',
+            save_path=args.output_path + 'auto_log.lpg',
             inference_config=config,
             pids=pid,
             process_name=None,
@@ -254,7 +253,11 @@ def main():
             save_image(
                 image_numpy,
                 os.path.join(args.output_path, "cyclegan/{}.png".format(i)))
-            logger.info("Inference succeeded! The inference result has been saved in {}".format(os.path.join(args.output_path, "cyclegan/{}.png".format(i))))
+            logger.info(
+                "Inference succeeded! The inference result has been saved in {}"
+                .format(
+                    os.path.join(args.output_path,
+                                 "cyclegan/{}.png".format(i))))
             auto_logger.times.end(stamp=True)
             auto_logger.report()
             metric_file = os.path.join(args.output_path, "cyclegan/metric.txt")
...
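For reference, a hedged sketch of how the auto_log timer configured in the hunks above is typically wrapped around one inference pass. The constructor keywords mirror those shown in the diff; model_name and the times.start()/stamp() calls are assumptions about the auto_log API that are not visible here and may need adjusting for your auto_log version:

import os
import auto_log

def benchmark_one_pass(config, output_path, data_shape, run_inference):
    # Time a single inference pass and write the benchmark report.
    auto_logger = auto_log.AutoLogger(
        model_name='cyclegan',            # assumed; the real value sits on a collapsed diff line
        model_precision='fp32',
        batch_size=1,
        data_shape=data_shape,
        save_path=os.path.join(output_path, 'auto_log.lpg'),
        inference_config=config,          # the paddle.inference.Config built earlier
        pids=os.getpid(),
        process_name=None)

    auto_logger.times.start()             # assumed: mark the start of the pass
    outputs = run_inference()             # caller runs predictor.run() here
    auto_logger.times.stamp()             # assumed: mark the end of inference
    auto_logger.times.end(stamp=True)     # close the timing record, as in the hunk above
    auto_logger.report()                  # write the benchmark log to save_path
    return outputs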