diff --git a/PaddleCV/image_classification/models/autodl.py b/PaddleCV/image_classification/models/autodl.py index 915bc1631cc841194e2217186ee7849e92735b30..0c23d5b14cbd2b95046c06ea2e0604cf2aa8a07b 100644 --- a/PaddleCV/image_classification/models/autodl.py +++ b/PaddleCV/image_classification/models/autodl.py @@ -505,9 +505,9 @@ def StemConv1(input, C_out): return bn_a class NetworkImageNet(object): - def __init__(self, arch='DARTS_6M'): + def __init__(self, arch='DARTS_6M', class_dim=1000): self.params = train_parameters - self.class_num = 1000 + self.class_num = class_dim self.init_channel = 48 self._layers = 14 self._auxiliary = False