提交 a3bfb074 编写于 作者: C chenguowei01

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleSeg into dygraph

......@@ -17,12 +17,12 @@ import os
import numpy as np
from PIL import Image
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
......@@ -61,8 +61,8 @@ class ADE20K(Dataset):
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
savepath=segenv.DATA_HOME,
extrapath=segenv.DATA_HOME,
extraname='ADEChallengeData2016')
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
......
......@@ -14,12 +14,12 @@
import os
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
......@@ -49,7 +49,7 @@ class OpticDiscSeg(Dataset):
raise Exception(
"`data_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
url=URL, savepath=segenv.DATA_HOME, extrapath=segenv.DATA_HOME)
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
self.dataset_root))
......
......@@ -14,12 +14,12 @@
import os
import paddleseg.env as segenv
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
......@@ -59,8 +59,8 @@ class PascalVOC(Dataset):
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
savepath=segenv.DATA_HOME,
extrapath=segenv.DATA_HOME,
extraname='VOCdevkit')
elif not os.path.exists(self.dataset_root):
raise Exception('there is not `dataset_root`: {}.'.format(
......
# coding:utf-8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from paddleseg.utils import logger
def _get_user_home():
return os.path.expanduser('~')
def _get_seg_home():
if 'SEG_HOME' in os.environ:
home_path = os.environ['SEG_HOME']
if os.path.exists(home_path):
if os.path.isdir(home_path):
return home_path
else:
logger.warning('SEG_HOME {} is a file!'.format(home_path))
else:
return home_path
return os.path.join(_get_user_home(), '.paddleseg')
def _get_sub_home(directory):
home = os.path.join(_get_seg_home(), directory)
if not os.path.exists(home):
os.makedirs(home)
return home
USER_HOME = _get_user_home()
SEG_HOME = _get_seg_home()
DATA_HOME = _get_sub_home('dataset')
TMP_HOME = _get_sub_home('tmp')
PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model')
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.utils import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU
class PAM(nn.Layer):
"""Position attention module"""
def __init__(self, in_channels):
super(PAM, self).__init__()
mid_channels = in_channels // 8
self.query_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.key_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.value_conv = nn.Conv2d(in_channels, in_channels, 1, 1)
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, _, h, w = x.shape
# query: n, h * w, c1
query = self.query_conv(x)
query = paddle.reshape(query, (n, -1, h * w))
query = paddle.transpose(query, (0, 2, 1))
# key: n, c1, h * w
key = self.key_conv(x)
key = paddle.reshape(key, (n, -1, h * w))
# sim: n, h * w, h * w
sim = paddle.bmm(query, key)
sim = F.softmax(sim, axis=-1)
value = self.value_conv(x)
value = paddle.reshape(value, (n, -1, h * w))
sim = paddle.transpose(sim, (0, 2, 1))
# feat: from (n, c2, h * w) -> (n, c2, h, w)
feat = paddle.bmm(value, sim)
feat = paddle.reshape(feat, (n, -1, h, w))
out = self.gamma * feat + x
return out
class CAM(nn.Layer):
"""Channel attention module"""
def __init__(self):
super(CAM, self).__init__()
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, c, h, w = x.shape
# query: n, c, h * w
query = paddle.reshape(x, (n, c, h * w))
# key: n, h * w, c
key = paddle.reshape(x, (n, c, h * w))
key = paddle.transpose(key, (0, 2, 1))
# sim: n, c, c
sim = paddle.bmm(query, key)
# The danet author claims that this can avoid gradient divergence
sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim
sim = F.softmax(sim, axis=-1)
# feat: from (n, c, h * w) to (n, c, h, w)
value = paddle.reshape(x, (n, c, h * w))
feat = paddle.bmm(sim, value)
feat = paddle.reshape(feat, (n, c, h, w))
out = self.gamma * feat + x
return out
class DAHead(nn.Layer):
"""
The Dual attention head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
"""
def __init__(self, num_classes, in_channels=None):
super(DAHead, self).__init__()
in_channels = in_channels[-1]
inter_channels = in_channels // 4
self.channel_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.position_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.pam = PAM(inter_channels)
self.cam = CAM()
self.conv1 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.conv2 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.aux_head_pam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.aux_head_cam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.cls_head = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.init_weight()
def forward(self, x, label=None):
feats = x[-1]
channel_feats = self.channel_conv(feats)
channel_feats = self.cam(channel_feats)
channel_feats = self.conv1(channel_feats)
cam_head = self.aux_head_cam(channel_feats)
position_feats = self.position_conv(feats)
position_feats = self.pam(position_feats)
position_feats = self.conv2(position_feats)
pam_head = self.aux_head_pam(position_feats)
feats_sum = position_feats + channel_feats
cam_logit = self.aux_head_cam(channel_feats)
pam_logit = self.aux_head_cam(position_feats)
logit = self.cls_head(feats_sum)
return [logit, cam_logit, pam_logit]
def init_weight(self):
"""Initialize the parameters of model parts."""
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2d):
param_init.normal_init(sublayer.weight, scale=0.001)
elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
param_init.constant_init(sublayer.bias, value=0)
@manager.MODELS.add_component
class DANet(nn.Layer):
"""
The DANet implementation based on PaddlePaddle.
The original article refers to
Fu, jun, et al. "Dual Attention Network for Scene Segmentation"
(https://arxiv.org/pdf/1809.02983.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): values in the tuple indicate the indices of output of backbone.
Only the last indice is used.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None):
super(DANet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = DAHead(num_classes=num_classes, in_channels=in_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
"""
Initialize the parameters of model parts.
Args:
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if pretrained is not None:
if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained)
else:
raise Exception(
'Pretrained model is not found: {}'.format(pretrained))
......@@ -14,36 +14,41 @@
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph import Sequential, Conv2D
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager
from paddleseg.models.common.layer_libs import ConvBnRelu
from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
class SpatialGatherBlock(fluid.dygraph.Layer):
class SpatialGatherBlock(nn.Layer):
"""Aggregation layer to compute the pixel-region representation"""
def forward(self, pixels, regions):
n, c, h, w = pixels.shape
_, k, _, _ = regions.shape
# pixels: from (n, c, h, w) to (n, h*w, c)
pixels = fluid.layers.reshape(pixels, (n, c, h * w))
pixels = fluid.layers.transpose(pixels, (0, 2, 1))
pixels = paddle.reshape(pixels, (n, c, h * w))
pixels = paddle.transpose(pixels, (0, 2, 1))
# regions: from (n, k, h, w) to (n, k, h*w)
regions = fluid.layers.reshape(regions, (n, k, h * w))
regions = fluid.layers.softmax(regions, axis=2)
regions = paddle.reshape(regions, (n, k, h * w))
regions = F.softmax(regions, axis=2)
# feats: from (n, k, c) to (n, c, k, 1)
feats = fluid.layers.matmul(regions, pixels)
feats = fluid.layers.transpose(feats, (0, 2, 1))
feats = fluid.layers.unsqueeze(feats, axes=[-1])
feats = paddle.bmm(regions, pixels)
feats = paddle.transpose(feats, (0, 2, 1))
feats = paddle.unsqueeze(feats, axis=-1)
return feats
class SpatialOCRModule(fluid.dygraph.Layer):
class SpatialOCRModule(nn.Layer):
"""Aggregate the global object representation to update the representation for each pixel"""
def __init__(self,
in_channels,
key_channels,
......@@ -53,163 +58,180 @@ class SpatialOCRModule(fluid.dygraph.Layer):
self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
self.dropout_rate = dropout_rate
self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1)
self.conv1x1 = nn.Sequential(
nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1))
def forward(self, pixels, regions):
context = self.attention_block(pixels, regions)
feats = fluid.layers.concat([context, pixels], axis=1)
feats = paddle.concat([context, pixels], axis=1)
feats = self.conv1x1(feats)
feats = fluid.layers.dropout(feats, self.dropout_rate)
return feats
class ObjectAttentionBlock(fluid.dygraph.Layer):
class ObjectAttentionBlock(nn.Layer):
"""A self-attention module."""
def __init__(self, in_channels, key_channels):
super(ObjectAttentionBlock, self).__init__()
self.in_channels = in_channels
self.key_channels = key_channels
self.f_pixel = Sequential(
ConvBnRelu(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1))
self.f_pixel = nn.Sequential(
ConvBNReLU(in_channels, key_channels, 1),
ConvBNReLU(key_channels, key_channels, 1))
self.f_object = Sequential(
ConvBnRelu(in_channels, key_channels, 1),
ConvBnRelu(key_channels, key_channels, 1))
self.f_object = nn.Sequential(
ConvBNReLU(in_channels, key_channels, 1),
ConvBNReLU(key_channels, key_channels, 1))
self.f_down = ConvBnRelu(in_channels, key_channels, 1)
self.f_down = ConvBNReLU(in_channels, key_channels, 1)
self.f_up = ConvBnRelu(key_channels, in_channels, 1)
self.f_up = ConvBNReLU(key_channels, in_channels, 1)
def forward(self, x, proxy):
n, _, h, w = x.shape
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query = self.f_pixel(x)
query = fluid.layers.reshape(query, (n, self.key_channels, -1))
query = fluid.layers.transpose(query, (0, 2, 1))
query = paddle.reshape(query, (n, self.key_channels, -1))
query = paddle.transpose(query, (0, 2, 1))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key = self.f_object(proxy)
key = fluid.layers.reshape(key, (n, self.key_channels, -1))
key = paddle.reshape(key, (n, self.key_channels, -1))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value = self.f_down(proxy)
value = fluid.layers.reshape(value, (n, self.key_channels, -1))
value = fluid.layers.transpose(value, (0, 2, 1))
value = paddle.reshape(value, (n, self.key_channels, -1))
value = paddle.transpose(value, (0, 2, 1))
# sim_map (n, h1*w1, h2*w2)
sim_map = fluid.layers.matmul(query, key)
sim_map = paddle.bmm(query, key)
sim_map = (self.key_channels**-.5) * sim_map
sim_map = fluid.layers.softmax(sim_map, axis=-1)
sim_map = F.softmax(sim_map, axis=-1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context = fluid.layers.matmul(sim_map, value)
context = fluid.layers.transpose(context, (0, 2, 1))
context = fluid.layers.reshape(context, (n, self.key_channels, h, w))
context = paddle.bmm(sim_map, value)
context = paddle.transpose(context, (0, 2, 1))
context = paddle.reshape(context, (n, self.key_channels, h, w))
context = self.f_up(context)
return context
@manager.MODELS.add_component
class OCRNet(fluid.dygraph.Layer):
class OCRHead(nn.Layer):
"""
The Object contextual representation head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self,
num_classes,
backbone,
model_pretrained=None,
in_channels=None,
ocr_mid_channels=512,
ocr_key_channels=256,
ignore_index=255):
super(OCRNet, self).__init__()
ocr_key_channels=256):
super(OCRHead, self).__init__()
self.ignore_index = ignore_index
self.num_classes = num_classes
self.EPS = 1e-5
self.backbone = backbone
self.spatial_gather = SpatialGatherBlock()
self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
ocr_mid_channels)
self.conv3x3_ocr = ConvBnRelu(
in_channels, ocr_mid_channels, 3, padding=1)
self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
self.aux_head = Sequential(
ConvBnRelu(in_channels, in_channels, 3, padding=1),
Conv2D(in_channels, self.num_classes, 1))
self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
self.init_weight(model_pretrained)
self.conv3x3_ocr = ConvBNReLU(
in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
self.aux_head = AuxLayer(in_channels[self.indices[0]],
in_channels[self.indices[0]], self.num_classes)
self.init_weight()
def forward(self, x, label=None):
feats = self.backbone(x)
feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]
soft_regions = self.aux_head(feats)
pixels = self.conv3x3_ocr(feats)
soft_regions = self.aux_head(feat_shallow)
pixels = self.conv3x3_ocr(feat_deep)
object_regions = self.spatial_gather(pixels, soft_regions)
ocr = self.spatial_ocr(pixels, object_regions)
logit = self.cls_head(ocr)
logit = fluid.layers.resize_bilinear(logit, x.shape[2:])
if self.training:
soft_regions = fluid.layers.resize_bilinear(soft_regions,
x.shape[2:])
cls_loss = self._get_loss(logit, label)
aux_loss = self._get_loss(soft_regions, label)
return cls_loss + 0.4 * aux_loss
score_map = fluid.layers.softmax(logit, axis=1)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
return pred, score_map
def init_weight(self, pretrained_model=None):
return [logit, soft_regions]
def init_weight(self):
"""Initialize the parameters of model parts."""
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2d):
param_init.normal_init(sublayer.weight, scale=0.001)
elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
param_init.constant_init(sublayer.bias, value=0)
@manager.MODELS.add_component
class OCRNet(nn.Layer):
"""
The OCRNet implementation based on PaddlePaddle.
The original article refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of pixel representation.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None,
ocr_mid_channels=512,
ocr_key_channels=256):
super(OCRNet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = OCRHead(
num_classes=num_classes,
in_channels=in_channels,
ocr_mid_channels=ocr_mid_channels,
ocr_key_channels=ocr_key_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model.. Defaults to None.
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
if pretrained is not None:
if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
raise Exception(
'Pretrained model is not found: {}'.format(pretrained))
......@@ -12,13 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import numpy as np
import math
import cv2
import tempfile
import paddle.fluid as fluid
from urllib.parse import urlparse, unquote
from . import logger
import filelock
import paddleseg.env as segenv
from paddleseg.utils import logger
from paddleseg.utils.download import download_file_and_uncompress
@contextlib.contextmanager
def generate_tempdir(directory: str = None, **kwargs):
'''Generate a temporary directory'''
directory = segenv.TMP_HOME if not directory else directory
with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir:
yield _dir
def seconds_to_hms(seconds):
......@@ -32,6 +47,18 @@ def seconds_to_hms(seconds):
def load_pretrained_model(model, pretrained_model):
if pretrained_model is not None:
logger.info('Load pretrained model from {}'.format(pretrained_model))
# download pretrained model from url
if urlparse(pretrained_model).netloc:
pretrained_model = unquote(pretrained_model)
savename = pretrained_model.split('/')[-1].split('.')[0]
with generate_tempdir() as _dir:
with filelock.FileLock(os.path.join(segenv.TMP_HOME, savename)):
pretrained_model = download_file_and_uncompress(
pretrained_model,
savepath=_dir,
extrapath=segenv.PRETRAINED_MODEL_HOME,
extraname=savename)
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
try:
......
......@@ -112,9 +112,10 @@ def main(args):
val_dataset = cfg.val_dataset if args.do_eval else None
losses = cfg.loss
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)
train(
cfg.model,
......
......@@ -19,7 +19,7 @@ from paddle.distributed import ParallelEnv
import paddleseg
from paddleseg.cvlibs import manager
from paddleseg.utils import get_environ_info, Config
from paddleseg.utils import get_environ_info, Config, logger
from paddleseg.core import evaluate
......@@ -56,9 +56,10 @@ def main(args):
'The verification dataset is not specified in the configuration file.'
)
print('---------------Config Information---------------')
print(cfg)
print('------------------------------------------------')
msg = '\n---------------Config Information---------------\n'
msg += str(cfg)
msg += '------------------------------------------------'
logger.info(msg)
evaluate(
cfg.model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册