diff --git a/paddlex/cv/models/hrnet.py b/paddlex/cv/models/hrnet.py index 0eec2be561911fd18bed97eef3e49b897c60510a..304cf0b1568d473df615e12ffbaa96e9d681af20 100644 --- a/paddlex/cv/models/hrnet.py +++ b/paddlex/cv/models/hrnet.py @@ -77,6 +77,7 @@ class HRNet(DeepLabv3p): self.class_weight = class_weight self.ignore_index = ignore_index self.labels = None + self.fixed_input_shape = None def build_net(self, mode='train'): model = paddlex.cv.nets.segmentation.HRNet( @@ -86,7 +87,8 @@ class HRNet(DeepLabv3p): use_bce_loss=self.use_bce_loss, use_dice_loss=self.use_dice_loss, class_weight=self.class_weight, - ignore_index=self.ignore_index) + ignore_index=self.ignore_index, + fixed_input_shape=self.fixed_input_shape) inputs = model.generate_inputs() model_out = model.build_net(inputs) outputs = OrderedDict() @@ -170,6 +172,6 @@ class HRNet(DeepLabv3p): return super(HRNet, self).train( num_epochs, train_dataset, train_batch_size, eval_dataset, save_interval_epochs, log_interval_steps, save_dir, - pretrain_weights, optimizer, learning_rate, lr_decay_power, - use_vdl, sensitivities_file, eval_metric_loss, early_stop, + pretrain_weights, optimizer, learning_rate, lr_decay_power, use_vdl, + sensitivities_file, eval_metric_loss, early_stop, early_stop_patience, resume_checkpoint) diff --git a/paddlex/cv/nets/segmentation/hrnet.py b/paddlex/cv/nets/segmentation/hrnet.py index 65f3bfbc2514a325da28c022b2ae2c434d7a2eb4..b0bf10d5dd172851b12234a0e07a059f58b82773 100644 --- a/paddlex/cv/nets/segmentation/hrnet.py +++ b/paddlex/cv/nets/segmentation/hrnet.py @@ -38,7 +38,8 @@ class HRNet(object): use_bce_loss=False, use_dice_loss=False, class_weight=None, - ignore_index=255): + ignore_index=255, + fixed_input_shape=None): # dice_loss或bce_loss只适用两类分割中 if num_classes > 2 and (use_bce_loss or use_dice_loss): raise ValueError( @@ -66,6 +67,7 @@ class HRNet(object): self.use_dice_loss = use_dice_loss self.class_weight = class_weight self.ignore_index = ignore_index + self.fixed_input_shape = fixed_input_shape self.backbone = paddlex.cv.nets.hrnet.HRNet( width=width, feature_maps="stage4") @@ -131,8 +133,16 @@ class HRNet(object): def generate_inputs(self): inputs = OrderedDict() - inputs['image'] = fluid.data( - dtype='float32', shape=[None, 3, None, None], name='image') + + if self.fixed_input_shape is not None: + input_shape = [ + None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0] + ] + inputs['image'] = fluid.data( + dtype='float32', shape=input_shape, name='image') + else: + inputs['image'] = fluid.data( + dtype='float32', shape=[None, 3, None, None], name='image') if self.mode == 'train': inputs['label'] = fluid.data( dtype='int32', shape=[None, 1, None, None], name='label')