From a99e5064fda751b9ad00c309a889cb5c3cebf0f2 Mon Sep 17 00:00:00 2001 From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com> Date: Tue, 18 Apr 2023 15:42:44 +0800 Subject: [PATCH] update npu inference api (#779) --- tools/inference.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tools/inference.py b/tools/inference.py index abaa469..0c77982 100644 --- a/tools/inference.py +++ b/tools/inference.py @@ -118,7 +118,7 @@ def create_predictor(model_path, elif device == "cpu": config.disable_gpu() elif device == "npu": - config.enable_npu() + config.enable_custom_device('npu') elif device == "xpu": config.enable_xpu() else: @@ -174,11 +174,10 @@ def main(): random.seed(args.seed) np.random.seed(args.seed) cfg = get_config(args.config_file, args.opt, show=True) - predictor, config = create_predictor(args.model_path, args.device, args.run_mode, - args.batch_size, args.min_subgraph_size, - args.use_dynamic_shape, args.trt_min_shape, - args.trt_max_shape, args.trt_opt_shape, - args.trt_calib_mode) + predictor, config = create_predictor( + args.model_path, args.device, args.run_mode, args.batch_size, + args.min_subgraph_size, args.use_dynamic_shape, args.trt_min_shape, + args.trt_max_shape, args.trt_opt_shape, args.trt_calib_mode) input_handles = [ predictor.get_input_handle(name) for name in predictor.get_input_names() @@ -225,7 +224,7 @@ def main(): elif model_type == "cyclegan": import auto_log logger = get_logger(name='ppgan') - + size = data['A'].shape pid = os.getpid() auto_logger = auto_log.AutoLogger( @@ -233,7 +232,7 @@ def main(): model_precision=args.run_mode, batch_size=args.batch_size, data_shape=size, - save_path=args.output_path+'auto_log.lpg', + save_path=args.output_path + 'auto_log.lpg', inference_config=config, pids=pid, process_name=None, @@ -254,7 +253,11 @@ def main(): save_image( image_numpy, os.path.join(args.output_path, "cyclegan/{}.png".format(i))) - logger.info("Inference succeeded! The inference result has been saved in {}".format(os.path.join(args.output_path, "cyclegan/{}.png".format(i)))) + logger.info( + "Inference succeeded! The inference result has been saved in {}" + .format( + os.path.join(args.output_path, + "cyclegan/{}.png".format(i)))) auto_logger.times.end(stamp=True) auto_logger.report() metric_file = os.path.join(args.output_path, "cyclegan/metric.txt") -- GitLab