提交 1874ec9b 编写于 作者: W wuyefeilin 提交者: Zeyu Chen

Rm importlib (#124)

* remove importlib

* remove importlib
上级 d4df83f4
......@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import struct
import importlib
import paddle.fluid as fluid
import numpy as np
......@@ -26,6 +24,7 @@ from utils.config import cfg
from loss import multi_softmax_with_loss
from loss import multi_dice_loss
from loss import multi_bce_loss
from models.modeling import deeplab, unet, icnet, pspnet, hrnet
class ModelPhase(object):
......@@ -70,40 +69,23 @@ class ModelPhase(object):
return False
def map_model_name(model_name):
name_dict = {
"unet": "unet.unet",
"deeplabv3p": "deeplab.deeplabv3p",
"icnet": "icnet.icnet",
"pspnet": "pspnet.pspnet",
"hrnet": "hrnet.hrnet"
}
if model_name in name_dict.keys():
return name_dict[model_name]
def seg_model(image, class_num):
model_name = cfg.MODEL.MODEL_NAME
if model_name == 'unet':
logits = unet.unet(image, class_num)
elif model_name == 'deeplabv3p':
logits = deeplab.deeplabv3p(image, class_num)
elif model_name == 'icnet':
logits = icnet.icnet(image, class_num)
elif model_name == 'pspnet':
logits = pspnet.pspnet(image, class_num)
elif model_name == 'hrnet':
logits = hrnet.hrnet(image, class_num)
else:
raise Exception(
"unknow model name, only support unet, deeplabv3p, icnet")
def get_func(func_name):
"""Helper to return a function object by name. func_name must identify a
function in this module or the path to a function relative to the base
'modeling' module.
"""
if func_name == '':
return None
try:
parts = func_name.split('.')
# Refers to a function in this module
if len(parts) == 1:
return globals()[parts[0]]
# Otherwise, assume we're referencing a module under modeling
module_name = 'models.' + '.'.join(parts[:-1])
module = importlib.import_module(module_name)
return getattr(module, parts[-1])
except Exception:
print('Failed to find function: {}'.format(func_name))
return module
"unknow model name, only support unet, deeplabv3p, icnet, pspnet, hrnet"
)
return logits
def softmax(logit):
......@@ -124,6 +106,7 @@ def sigmoid_to_softmax(logit):
logit = fluid.layers.transpose(logit, [0, 3, 1, 2])
return logit
def export_preprocess(image):
"""导出模型的预处理流程"""
......@@ -135,10 +118,7 @@ def export_preprocess(image):
h_fix = cfg.AUG.FIX_RESIZE_SIZE[1]
w_fix = cfg.AUG.FIX_RESIZE_SIZE[0]
image = fluid.layers.resize_bilinear(
image,
out_shape=[h_fix, w_fix],
align_corners=False,
align_mode=0)
image, out_shape=[h_fix, w_fix], align_corners=False, align_mode=0)
elif cfg.AUG.AUG_METHOD == 'rangescaling':
size = cfg.AUG.INF_RESIZE_VALUE
value = fluid.layers.reduce_max(origin_shape)
......@@ -160,8 +140,7 @@ def export_preprocess(image):
right = pad_target[1] - valid_shape[1]
paddings = fluid.layers.concat([up, down, left, right])
paddings = fluid.layers.cast(paddings, 'int32')
image = fluid.layers.pad2d(
image, paddings=paddings, pad_value=127.5)
image = fluid.layers.pad2d(image, paddings=paddings, pad_value=127.5)
# normalize
mean = np.array(cfg.MEAN).reshape(1, len(cfg.MEAN), 1, 1)
......@@ -199,7 +178,8 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
shape=[-1, -1, -1, cfg.DATASET.DATA_DIM],
dtype='float32',
append_batch_size=False)
image, valid_shape, origin_shape = export_preprocess(origin_image)
image, valid_shape, origin_shape = export_preprocess(
origin_image)
else:
image = fluid.layers.data(
......@@ -217,9 +197,6 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
iterable=False,
use_double_buffer=True)
model_name = map_model_name(cfg.MODEL.MODEL_NAME)
model_func = get_func("modeling." + model_name)
loss_type = cfg.SOLVER.LOSS
if not isinstance(loss_type, list):
loss_type = list(loss_type)
......@@ -238,7 +215,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
raise Exception(
"softmax loss can not combine with dice loss or bce loss"
)
logits = model_func(image, class_num)
logits = seg_model(image, class_num)
# 根据选择的loss函数计算相应的损失函数
if ModelPhase.is_train(phase) or ModelPhase.is_eval(phase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册