未验证 提交 2b69eaea 编写于 作者: J Jason 提交者: GitHub

Merge pull request #141 from PaddlePaddle/develop_hrnet

fixed input shape for hrnet
......@@ -77,6 +77,7 @@ class HRNet(DeepLabv3p):
self.class_weight = class_weight
self.ignore_index = ignore_index
self.labels = None
self.fixed_input_shape = None
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.HRNet(
......@@ -86,7 +87,8 @@ class HRNet(DeepLabv3p):
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index)
ignore_index=self.ignore_index,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
......@@ -170,6 +172,6 @@ class HRNet(DeepLabv3p):
return super(HRNet, self).train(
num_epochs, train_dataset, train_batch_size, eval_dataset,
save_interval_epochs, log_interval_steps, save_dir,
pretrain_weights, optimizer, learning_rate, lr_decay_power,
use_vdl, sensitivities_file, eval_metric_loss, early_stop,
pretrain_weights, optimizer, learning_rate, lr_decay_power, use_vdl,
sensitivities_file, eval_metric_loss, early_stop,
early_stop_patience, resume_checkpoint)
......@@ -38,7 +38,8 @@ class HRNet(object):
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255):
ignore_index=255,
fixed_input_shape=None):
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
......@@ -66,6 +67,7 @@ class HRNet(object):
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.fixed_input_shape = fixed_input_shape
self.backbone = paddlex.cv.nets.hrnet.HRNet(
width=width, feature_maps="stage4")
......@@ -131,6 +133,14 @@ class HRNet(object):
def generate_inputs(self):
inputs = OrderedDict()
if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册