提交 a3bfb074 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into dygraph

...@@ -17,12 +17,12 @@ import os ...@@ -17,12 +17,12 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip" URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
...@@ -61,8 +61,8 @@ class ADE20K(Dataset): ...@@ -61,8 +61,8 @@ class ADE20K(Dataset):
"`dataset_root` not set and auto download disabled.") "`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=segenv.DATA_HOME,
extrapath=DATA_HOME, extrapath=segenv.DATA_HOME,
extraname='ADEChallengeData2016') extraname='ADEChallengeData2016')
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
import os import os
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip" URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
...@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset): ...@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
raise Exception( raise Exception(
"`data_root` not set and auto download disabled.") "`data_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME) url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME)
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
self.dataset_root)) self.dataset_root))
......
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
import os import os
import paddleseg.env as segenv
from .dataset import Dataset from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar" URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
...@@ -59,8 +59,8 @@ class PascalVOC(Dataset): ...@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
"`dataset_root` not set and auto download disabled.") "`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress( self.dataset_root = download_file_and_uncompress(
url=URL, url=URL,
savepath=DATA_HOME, savepath=segenv.DATA_HOME,
extrapath=DATA_HOME, extrapath=segenv.DATA_HOME,
extraname='VOCdevkit') extraname='VOCdevkit')
elif not os.path.exists(self.dataset_root): elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format( raise Exception('there is not `dataset_root`: {}.'.format(
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from paddleseg.utils import logger
def _get_user_home():
return os.path.expanduser('~')
def _get_seg_home():
if 'SEG_HOME' in os.environ:
home_path = os.environ['SEG_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
logger.warning('SEG_HOME {} is a file!'.format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleseg')
def _get_sub_home(directory):
home = os.path.join(_get_seg_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.utils import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU
class PAM(nn.Layer):
"""Position attention module"""
def __init__(self, in_channels):
super(PAM, self).__init__()
mid_channels = in_channels // 8
self.query_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.key_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.value_conv = nn.Conv2d(in_channels, in_channels, 1, 1)
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, _, h, w = x.shape
# query: n, h * w, c1
query = self.query_conv(x)
query = paddle.reshape(query, (n, -1, h * w))
query = paddle.transpose(query, (0, 2, 1))
# key: n, c1, h * w
key = self.key_conv(x)
key = paddle.reshape(key, (n, -1, h * w))
# sim: n, h * w, h * w
sim = paddle.bmm(query, key)
sim = F.softmax(sim, axis=-1)
value = self.value_conv(x)
value = paddle.reshape(value, (n, -1, h * w))
sim = paddle.transpose(sim, (0, 2, 1))
# feat: from (n, c2, h * w) -> (n, c2, h, w)
feat = paddle.bmm(value, sim)
feat = paddle.reshape(feat, (n, -1, h, w))
out = self.gamma * feat + x
return out
class CAM(nn.Layer):
"""Channel attention module"""
def __init__(self):
super(CAM, self).__init__()
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, c, h, w = x.shape
# query: n, c, h * w
query = paddle.reshape(x, (n, c, h * w))
# key: n, h * w, c
key = paddle.reshape(x, (n, c, h * w))
key = paddle.transpose(key, (0, 2, 1))
# sim: n, c, c
sim = paddle.bmm(query, key)
# The danet author claims that this can avoid gradient divergence
sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim
sim = F.softmax(sim, axis=-1)
# feat: from (n, c, h * w) to (n, c, h, w)
value = paddle.reshape(x, (n, c, h * w))
feat = paddle.bmm(sim, value)
feat = paddle.reshape(feat, (n, c, h, w))
out = self.gamma * feat + x
return out
class DAHead(nn.Layer):
"""
The Dual attention head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
"""
def __init__(self, num_classes, in_channels=None):
super(DAHead, self).__init__()
in_channels = in_channels[-1]
inter_channels = in_channels // 4
self.channel_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.position_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.pam = PAM(inter_channels)
self.cam = CAM()
self.conv1 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.conv2 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.aux_head_pam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.aux_head_cam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.cls_head = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.init_weight()
def forward(self, x, label=None):
feats = x[-1]
channel_feats = self.channel_conv(feats)
channel_feats = self.cam(channel_feats)
channel_feats = self.conv1(channel_feats)
cam_head = self.aux_head_cam(channel_feats)
position_feats = self.position_conv(feats)
position_feats = self.pam(position_feats)
position_feats = self.conv2(position_feats)
pam_head = self.aux_head_pam(position_feats)
feats_sum = position_feats + channel_feats
cam_logit = self.aux_head_cam(channel_feats)
pam_logit = self.aux_head_cam(position_feats)
logit = self.cls_head(feats_sum)
return [logit, cam_logit, pam_logit]
def init_weight(self):
"""Initialize the parameters of model parts."""
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2d):
param_init.normal_init(sublayer.weight, scale=0.001)
elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
param_init.constant_init(sublayer.bias, value=0)
@manager.MODELS.add_component
class DANet(nn.Layer):
"""
The DANet implementation based on PaddlePaddle.
The original article refers to
Fu, jun, et al. "Dual Attention Network for Scene Segmentation"
(https://arxiv.org/pdf/1809.02983.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): values in the tuple indicate the indices of output of backbone.
Only the last indice is used.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None):
super(DANet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = DAHead(num_classes=num_classes, in_channels=in_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
"""
Initialize the parameters of model parts.
Args:
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if pretrained is not None:
if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained)
else:
raise Exception(
'Pretrained model is not found: {}'.format(pretrained))
...@@ -14,36 +14,41 @@ ...@@ -14,36 +14,41 @@
import os import os
import paddle.fluid as fluid import paddle
from paddle.fluid.dygraph import Sequential, Conv2D import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager
from paddleseg.models.common.layer_libs import ConvBnRelu
from paddleseg import utils from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
class SpatialGatherBlock(fluid.dygraph.Layer): class SpatialGatherBlock(nn.Layer):
"""Aggregation layer to compute the pixel-region representation"""
def forward(self, pixels, regions): def forward(self, pixels, regions):
n, c, h, w = pixels.shape n, c, h, w = pixels.shape
_, k, _, _ = regions.shape _, k, _, _ = regions.shape
# pixels: from (n, c, h, w) to (n, h*w, c) # pixels: from (n, c, h, w) to (n, h*w, c)
pixels = fluid.layers.reshape(pixels, (n, c, h * w)) pixels = paddle.reshape(pixels, (n, c, h * w))
pixels = fluid.layers.transpose(pixels, (0, 2, 1)) pixels = paddle.transpose(pixels, (0, 2, 1))
# regions: from (n, k, h, w) to (n, k, h*w) # regions: from (n, k, h, w) to (n, k, h*w)
regions = fluid.layers.reshape(regions, (n, k, h * w)) regions = paddle.reshape(regions, (n, k, h * w))
regions = fluid.layers.softmax(regions, axis=2) regions = F.softmax(regions, axis=2)
# feats: from (n, k, c) to (n, c, k, 1) # feats: from (n, k, c) to (n, c, k, 1)
feats = fluid.layers.matmul(regions, pixels) feats = paddle.bmm(regions, pixels)
feats = fluid.layers.transpose(feats, (0, 2, 1)) feats = paddle.transpose(feats, (0, 2, 1))
feats = fluid.layers.unsqueeze(feats, axes=[-1]) feats = paddle.unsqueeze(feats, axis=-1)
return feats return feats
class SpatialOCRModule(fluid.dygraph.Layer): class SpatialOCRModule(nn.Layer):
"""Aggregate the global object representation to update the representation for each pixel"""
def __init__(self, def __init__(self,
in_channels, in_channels,
key_channels, key_channels,
...@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer): ...@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
self.attention_block = ObjectAttentionBlock(in_channels, key_channels) self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1) self.conv1x1 = nn.Sequential(
nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1))
def forward(self, pixels, regions): def forward(self, pixels, regions):
context = self.attention_block(pixels, regions) context = self.attention_block(pixels, regions)
feats = fluid.layers.concat([context, pixels], axis=1) feats = paddle.concat([context, pixels], axis=1)
feats = self.conv1x1(feats) feats = self.conv1x1(feats)
feats = fluid.layers.dropout(feats, self.dropout_rate)
return feats return feats
class ObjectAttentionBlock(fluid.dygraph.Layer): class ObjectAttentionBlock(nn.Layer):
"""A self-attention module."""
def __init__(self, in_channels, key_channels): def __init__(self, in_channels, key_channels):
super(ObjectAttentionBlock, self).__init__() super(ObjectAttentionBlock, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.key_channels = key_channels self.key_channels = key_channels
self.f_pixel = Sequential( self.f_pixel = nn.Sequential(
ConvBnRelu(in_channels, key_channels, 1), ConvBNReLU(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1)) ConvBNReLU(key_channels, key_channels, 1))
self.f_object = Sequential( self.f_object = nn.Sequential(
ConvBnRelu(in_channels, key_channels, 1), ConvBNReLU(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1)) ConvBNReLU(key_channels, key_channels, 1))
self.f_down = ConvBnRelu(in_channels, key_channels, 1) self.f_down = ConvBNReLU(in_channels, key_channels, 1)
self.f_up = ConvBnRelu(key_channels, in_channels, 1) self.f_up = ConvBNReLU(key_channels, in_channels, 1)
def forward(self, x, proxy): def forward(self, x, proxy):
n, _, h, w = x.shape n, _, h, w = x.shape
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels) # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query = self.f_pixel(x) query = self.f_pixel(x)
query = fluid.layers.reshape(query, (n, self.key_channels, -1)) query = paddle.reshape(query, (n, self.key_channels, -1))
query = fluid.layers.transpose(query, (0, 2, 1)) query = paddle.transpose(query, (0, 2, 1))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2) # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key = self.f_object(proxy) key = self.f_object(proxy)
key = fluid.layers.reshape(key, (n, self.key_channels, -1)) key = paddle.reshape(key, (n, self.key_channels, -1))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels) # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value = self.f_down(proxy) value = self.f_down(proxy)
value = fluid.layers.reshape(value, (n, self.key_channels, -1)) value = paddle.reshape(value, (n, self.key_channels, -1))
value = fluid.layers.transpose(value, (0, 2, 1)) value = paddle.transpose(value, (0, 2, 1))
# sim_map (n, h1*w1, h2*w2) # sim_map (n, h1*w1, h2*w2)
sim_map = fluid.layers.matmul(query, key) sim_map = paddle.bmm(query, key)
sim_map = (self.key_channels**-.5) * sim_map sim_map = (self.key_channels**-.5) * sim_map
sim_map = fluid.layers.softmax(sim_map, axis=-1) sim_map = F.softmax(sim_map, axis=-1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1) # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context = fluid.layers.matmul(sim_map, value) context = paddle.bmm(sim_map, value)
context = fluid.layers.transpose(context, (0, 2, 1)) context = paddle.transpose(context, (0, 2, 1))
context = fluid.layers.reshape(context, (n, self.key_channels, h, w)) context = paddle.reshape(context, (n, self.key_channels, h, w))
context = self.f_up(context) context = self.f_up(context)
return context return context
@manager.MODELS.add_component class OCRHead(nn.Layer):
class OCRNet(fluid.dygraph.Layer): """
The Object contextual representation head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
model_pretrained=None,
in_channels=None, in_channels=None,
ocr_mid_channels=512, ocr_mid_channels=512,
ocr_key_channels=256, ocr_key_channels=256):
ignore_index=255): super(OCRHead, self).__init__()
super(OCRNet, self).__init__()
self.ignore_index = ignore_index
self.num_classes = num_classes self.num_classes = num_classes
self.EPS = 1e-5
self.backbone = backbone
self.spatial_gather = SpatialGatherBlock() self.spatial_gather = SpatialGatherBlock()
self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels, self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
ocr_mid_channels) ocr_mid_channels)
self.conv3x3_ocr = ConvBnRelu(
in_channels, ocr_mid_channels, 3, padding=1)
self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
self.aux_head = Sequential( self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
ConvBnRelu(in_channels, in_channels, 3, padding=1),
Conv2D(in_channels, self.num_classes, 1))
self.init_weight(model_pretrained) self.conv3x3_ocr = ConvBNReLU(
in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
self.aux_head = AuxLayer(in_channels[self.indices[0]],
in_channels[self.indices[0]], self.num_classes)
self.init_weight()
def forward(self, x, label=None): def forward(self, x, label=None):
feats = self.backbone(x) feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]
soft_regions = self.aux_head(feats) soft_regions = self.aux_head(feat_shallow)
pixels = self.conv3x3_ocr(feats) pixels = self.conv3x3_ocr(feat_deep)
object_regions = self.spatial_gather(pixels, soft_regions) object_regions = self.spatial_gather(pixels, soft_regions)
ocr = self.spatial_ocr(pixels, object_regions) ocr = self.spatial_ocr(pixels, object_regions)
logit = self.cls_head(ocr) logit = self.cls_head(ocr)
logit = fluid.layers.resize_bilinear(logit, x.shape[2:]) return [logit, soft_regions]
if self.training: def init_weight(self):
soft_regions = fluid.layers.resize_bilinear(soft_regions, """Initialize the parameters of model parts."""
x.shape[2:]) for sublayer in self.sublayers():
cls_loss = self._get_loss(logit, label) if isinstance(sublayer, nn.Conv2d):
aux_loss = self._get_loss(soft_regions, label) param_init.normal_init(sublayer.weight, scale=0.001)
return cls_loss + 0.4 * aux_loss elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
score_map = fluid.layers.softmax(logit, axis=1) param_init.constant_init(sublayer.bias, value=0)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3]) @manager.MODELS.add_component
return pred, score_map class OCRNet(nn.Layer):
"""
def init_weight(self, pretrained_model=None): The OCRNet implementation based on PaddlePaddle.
The original article refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of pixel representation.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None,
ocr_mid_channels=512,
ocr_key_channels=256):
super(OCRNet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = OCRHead(
num_classes=num_classes,
in_channels=in_channels,
ocr_mid_channels=ocr_mid_channels,
ocr_key_channels=ocr_key_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args: Args:
pretrained_model ([str], optional): the path of pretrained model.. Defaults to None. pretrained ([str], optional): the path of pretrained model.. Defaults to None.
""" """
if pretrained_model is not None: if pretrained is not None:
if os.path.exists(pretrained_model): if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained_model) utils.load_pretrained_model(self, pretrained)
else: else:
raise Exception('Pretrained model is not found: {}'.format( raise Exception(
pretrained_model)) 'Pretrained model is not found: {}'.format(pretrained))
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
...@@ -12,13 +12,28 @@ ...@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import os import os
import numpy as np import numpy as np
import math import math
import cv2 import cv2
import tempfile
import paddle.fluid as fluid import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds): def seconds_to_hms(seconds):
...@@ -32,6 +47,18 @@ def seconds_to_hms(seconds): ...@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
def load_pretrained_model(model, pretrained_model): def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None: if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model)) logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model') ckpt_path = os.path.join(pretrained_model, 'model')
try: try:
......
...@@ -112,9 +112,10 @@ def main(args): ...@@ -112,9 +112,10 @@ def main(args):
val_dataset = cfg.val_dataset if args.do_eval else None val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss losses = cfg.loss
print('---------------Config Information---------------') msg = '\n---------------Config Information---------------\n'
print(cfg) msg += str(cfg)
print('------------------------------------------------') msg += '------------------------------------------------'
logger.info(msg)
train( train(
cfg.model, cfg.model,
......
...@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv ...@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
import paddleseg import paddleseg
from paddleseg.cvlibs import manager from paddleseg.cvlibs import manager
from paddleseg.utils import get_environ_info, Config from paddleseg.utils import get_environ_info, Config, logger
from paddleseg.core import evaluate from paddleseg.core import evaluate
...@@ -56,9 +56,10 @@ def main(args): ...@@ -56,9 +56,10 @@ def main(args):
'The verification dataset is not specified in the configuration file.' 'The verification dataset is not specified in the configuration file.'
) )
print('---------------Config Information---------------') msg = '\n---------------Config Information---------------\n'
print(cfg) msg += str(cfg)
print('------------------------------------------------') msg += '------------------------------------------------'
logger.info(msg)
evaluate( evaluate(
cfg.model, cfg.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册