diff --git a/demo/prune/train.py b/demo/prune/train.py index 28c907c98c58a2de1090b9e28d78081de5a0c7f7..23379a326f6c6a93ca4e45e58de6e30263ba7a34 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -208,7 +208,7 @@ def compress(args): val_program, fluid.global_scope(), params=params, - ratios=[FLAGS.pruned_ratio] * len(params), + ratios=[args.pruned_ratio] * len(params), place=place, only_graph=True) @@ -216,7 +216,7 @@ def compress(args): fluid.default_main_program(), fluid.global_scope(), params=params, - ratios=[FLAGS.pruned_ratio] * len(params), + ratios=[args.pruned_ratio] * len(params), place=place) _logger.info("FLOPs after pruning: {}".format(flops(pruned_program))) for i in range(args.num_epochs): diff --git a/requirements.txt b/requirements.txt index 8e3fcf66561d1965a047b8debd38a543373f534f..e00ec03503dc9034444ecbf0458b7716f0a06f8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ #paddlepaddle == 1.6.0rc0 +tqdm