diff --git a/demo/dygraph/quant/train.py b/demo/dygraph/quant/train.py
index 25c0a3d195e5129fe6b672b535d1ddcd5ac67c49..a9853653e4905bc2e82982a9d4e1ccf97e6cd818 100644
--- a/demo/dygraph/quant/train.py
+++ b/demo/dygraph/quant/train.py
@@ -117,9 +117,9 @@ def compress(args):
 
     pretrain = True if args.data == "imagenet" else False
     if args.model == "mobilenet_v1":
-        net = mobilenet_v1(pretrained=pretrain)
+        net = mobilenet_v1(pretrained=pretrain, num_classes=class_dim)
     elif args.model == "mobilenet_v3":
-        net = MobileNetV3_large_x1_0()
+        net = MobileNetV3_large_x1_0(class_dim=class_dim)
         if pretrain:
             load_dygraph_pretrain(net, args.pretrained_model, True)
     else: