diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index 7707cdb403d4be3499e3bfebdabdc3ac79ad0910..fac8a929be40acce2d801c3cdbbe89bb634bead3 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -82,7 +82,9 @@ class HRNet(fluid.dygraph.Layer): name="layer1_2") self.la1 = Layer1( - num_channels=self.stage1_num_channels[0], + num_channels=64, + num_blocks=self.stage1_num_blocks[0], + num_filters=self.stage1_num_channels[0], has_se=has_se, name="layer2") @@ -228,17 +230,22 @@ class ConvBNLayer(fluid.dygraph.Layer): class Layer1(fluid.dygraph.Layer): - def __init__(self, num_channels, has_se=False, name=None): + def __init__(self, + num_channels, + num_filters, + num_blocks, + has_se=False, + name=None): super(Layer1, self).__init__() self.bottleneck_block_list = [] - for i in range(4): + for i in range(num_blocks): bottleneck_block = self.add_sublayer( "bb_{}_{}".format(name, i + 1), BottleneckBlock( - num_channels=num_channels if i == 0 else 256, - num_filters=64, + num_channels=num_channels if i == 0 else num_filters * 4, + num_filters=num_filters, has_se=has_se, stride=1, downsample=True if i == 0 else False,