diff --git a/demo/nas/rl_nas_mobilenetv2.py b/demo/nas/rl_nas_mobilenetv2.py index 2fda58317cf36b024cd692afa291ebeacfaaa922..abf23fb9481f72c29f1ddffd1b4012961083b59f 100644 --- a/demo/nas/rl_nas_mobilenetv2.py +++ b/demo/nas/rl_nas_mobilenetv2.py @@ -182,7 +182,7 @@ if __name__ == '__main__': parser.add_argument( '--batch_size', type=int, default=256, help='batch size.') parser.add_argument( - '--class_dim', type=int, default=1000, help='classify number.') + '--class_dim', type=int, default=10, help='classify number.') parser.add_argument( '--data', type=str, diff --git a/demo/nas/sa_nas_mobilenetv2.py b/demo/nas/sa_nas_mobilenetv2.py index f1ebd81f54bf48365c10691cf09b2c0bc53b8963..9ef26bc5d1d75a10cd8ebf7183bb70b8dbd2b354 100644 --- a/demo/nas/sa_nas_mobilenetv2.py +++ b/demo/nas/sa_nas_mobilenetv2.py @@ -271,7 +271,7 @@ if __name__ == '__main__': parser.add_argument( '--batch_size', type=int, default=256, help='batch size.') parser.add_argument( - '--class_dim', type=int, default=1000, help='classify number.') + '--class_dim', type=int, default=10, help='classify number.') parser.add_argument( '--data', type=str,