diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py index a9853653e4905bc2e82982a9d4e1ccf97e6cd818..e46f9a05564b2276dc0e7bd30b97c984b5cec64f 100644 --- a/demo/dygraph/quant/train.py +++ b/demo/dygraph/quant/train.py @@ -44,7 +44,7 @@ _logger = get_logger(__name__, level=logging.INFO) parser = argparse.ArgumentParser(description=__doc__) add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable -add_arg('batch_size', int, 256, "Single Card Minibatch size.") +add_arg('batch_size', int, 128, "Single Card Minibatch size.") add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('model', str, "mobilenet_v3", "The target model.") add_arg('pretrained_model', str, "MobileNetV3_large_x1_0_ssld_pretrained", "Whether to use pretrained model.")