From 311f9a059d3e70c5ab1504abb08d9f1d9bc62afb Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Fri, 13 Dec 2019 15:30:04 +0800 Subject: [PATCH] update export process --- pdseg/models/model_builder.py | 18 ++++++++---------- pdseg/utils/collect.py | 6 ++++++ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/pdseg/models/model_builder.py b/pdseg/models/model_builder.py index e0b19510..fc1178af 100644 --- a/pdseg/models/model_builder.py +++ b/pdseg/models/model_builder.py @@ -126,24 +126,20 @@ def sigmoid_to_softmax(logit): def export_preprocess(image): """导出模型的预处理流程""" - width = cfg.EVAL_CROP_SIZE[0] - height = cfg.EVAL_CROP_SIZE[1] image = fluid.layers.transpose(image, [0, 3, 1, 2]) origin_shape = fluid.layers.shape(image)[-2:] # 不同AUG_METHOD方法的resize if cfg.AUG.AUG_METHOD == 'unpadding': - h = cfg.AUG.FIX_RESIZE_SIZE[1] - w = cfg.AUG.FIX_RESIZE_SIZE[0] + h_fix = cfg.AUG.FIX_RESIZE_SIZE[1] + w_fix = cfg.AUG.FIX_RESIZE_SIZE[0] image = fluid.layers.resize_bilinear( image, - out_shape=[h, w], + out_shape=[h_fix, w_fix], align_corners=False, align_mode=0) - if cfg.AUG.AUG_METHOD == 'stepscaling': - pass - if cfg.AUG.AUG_METHOD == 'rangescaling': + elif cfg.AUG.AUG_METHOD == 'rangescaling': size = cfg.AUG.INF_RESIZE_VALUE value = fluid.layers.reduce_max(origin_shape) scale = float(size) / value.astype('float32') @@ -153,7 +149,9 @@ def export_preprocess(image): # 存储resize后图像shape valid_shape = fluid.layers.shape(image)[-2:] - # padding 到eval_crop_size大小 + # padding到eval_crop_size大小 + width = cfg.EVAL_CROP_SIZE[0] + height = cfg.EVAL_CROP_SIZE[1] pad_target = fluid.layers.assign( np.array([height, width]).astype('float32')) up = fluid.layers.assign(np.array([0]).astype('float32')) @@ -171,7 +169,7 @@ def export_preprocess(image): std = np.array(cfg.STD).reshape(1, len(cfg.STD), 1, 1) std = fluid.layers.assign(std.astype('float32')) image = (image / 255 - mean) / std - # 很有必要,使后面的网络能通过image.shape获取特征图的shape + # 使后面的网络能通过类似image.shape获取特征图的shape image = fluid.layers.reshape( image, shape=[-1, cfg.DATASET.DATA_DIM, height, width]) return image, valid_shape, origin_shape diff --git a/pdseg/utils/collect.py b/pdseg/utils/collect.py index 78baf63f..3321ec12 100644 --- a/pdseg/utils/collect.py +++ b/pdseg/utils/collect.py @@ -122,6 +122,12 @@ class SegConfig(dict): len(self.MODEL.MULTI_LOSS_WEIGHT) != 3: self.MODEL.MULTI_LOSS_WEIGHT = [1.0, 0.4, 0.16] + if self.AUG.AUG_METHOD not in ['unpadding', 'stepscaling', 'rangescaling']: + raise ValueError( + 'AUG.AUG_METHOD config error, only support `unpadding`, `unpadding` and `rangescaling`' + ) + + def update_from_list(self, config_list): if len(config_list) % 2 != 0: raise ValueError( -- GitLab