diff --git a/paddlex/cv/nets/hrnet.py b/paddlex/cv/nets/hrnet.py index 19f9cb336bce66a7dc68d65e316440adf46857e4..a7934d385d4a53fd936410e37d3896fe21cb17ee 100644 --- a/paddlex/cv/nets/hrnet.py +++ b/paddlex/cv/nets/hrnet.py @@ -71,7 +71,7 @@ class HRNet(object): self.end_points = [] return - def net(self, input, class_dim=1000): + def net(self, input): width = self.width channels_2, channels_3, channels_4 = self.channels[width] num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 @@ -125,7 +125,7 @@ class HRNet(object): stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) out = fluid.layers.fc( input=pool, - size=class_dim, + size=self.num_classes, param_attr=ParamAttr( name='fc_weights', initializer=fluid.initializer.Uniform(-stdv, stdv)),