diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py index 25c0a3d195e5129fe6b672b535d1ddcd5ac67c49..a9853653e4905bc2e82982a9d4e1ccf97e6cd818 100644 --- a/demo/dygraph/quant/train.py +++ b/demo/dygraph/quant/train.py @@ -117,9 +117,9 @@ def compress(args): pretrain = True if args.data == "imagenet" else False if args.model == "mobilenet_v1": - net = mobilenet_v1(pretrained=pretrain) + net = mobilenet_v1(pretrained=pretrain, num_classes=class_dim) elif args.model == "mobilenet_v3": - net = MobileNetV3_large_x1_0() + net = MobileNetV3_large_x1_0(class_dim=class_dim) if pretrain: load_dygraph_pretrain(net, args.pretrained_model, True) else: