提交 4b5665d0 编写于 作者: M michaelowenliu

add ocrnet

...@@ -4,7 +4,7 @@ learning_rate: 0.01 ...@@ -4,7 +4,7 @@ learning_rate: 0.01
train_dataset: train_dataset:
type: Cityscapes type: Cityscapes
dataset_root: data/cityscapes dataset_root: /mnt/liuyi22/.cache/paddle/dataset/cityscapes
transforms: transforms:
- type: ResizeStepScaling - type: ResizeStepScaling
min_scale_factor: 0.5 min_scale_factor: 0.5
...@@ -18,7 +18,7 @@ train_dataset: ...@@ -18,7 +18,7 @@ train_dataset:
val_dataset: val_dataset:
type: Cityscapes type: Cityscapes
dataset_root: data/cityscapes dataset_root: /mnt/liuyi22/.cache/paddle/dataset/cityscapes
transforms: transforms:
- type: Normalize - type: Normalize
mode: val mode: val
......
...@@ -87,7 +87,8 @@ def seg_train(model, ...@@ -87,7 +87,8 @@ def seg_train(model,
out_labels = ["loss", "reader_cost", "batch_cost"] out_labels = ["loss", "reader_cost", "batch_cost"]
base_logger = callbacks.BaseLogger(period=log_iters) base_logger = callbacks.BaseLogger(period=log_iters)
train_logger = callbacks.TrainLogger(log_freq=log_iters) train_logger = callbacks.TrainLogger(log_freq=log_iters)
model_ckpt = callbacks.ModelCheckpoint(save_dir, save_params_only=False, period=save_interval_iters) model_ckpt = callbacks.ModelCheckpoint(
save_dir, save_params_only=False, period=save_interval_iters)
vdl = callbacks.VisualDL(log_dir=os.path.join(save_dir, "log")) vdl = callbacks.VisualDL(log_dir=os.path.join(save_dir, "log"))
cbks_list = [base_logger, train_logger, model_ckpt, vdl] cbks_list = [base_logger, train_logger, model_ckpt, vdl]
...@@ -120,7 +121,7 @@ def seg_train(model, ...@@ -120,7 +121,7 @@ def seg_train(model,
iter += 1 iter += 1
if iter > iters: if iter > iters:
break break
logs["reader_cost"] = timer.elapsed_time() logs["reader_cost"] = timer.elapsed_time()
############## 2 ################ ############## 2 ################
cbks.on_iter_begin(iter, logs) cbks.on_iter_begin(iter, logs)
...@@ -136,7 +137,7 @@ def seg_train(model, ...@@ -136,7 +137,7 @@ def seg_train(model,
loss = ddp_model.scale_loss(loss) loss = ddp_model.scale_loss(loss)
loss.backward() loss.backward()
ddp_model.apply_collective_grads() ddp_model.apply_collective_grads()
else: else:
logits = model(images) logits = model(images)
loss = loss_computation(logits, labels, losses) loss = loss_computation(logits, labels, losses)
...@@ -148,7 +149,7 @@ def seg_train(model, ...@@ -148,7 +149,7 @@ def seg_train(model,
model.clear_gradients() model.clear_gradients()
logs['loss'] = loss.numpy()[0] logs['loss'] = loss.numpy()[0]
logs["batch_cost"] = timer.elapsed_time() logs["batch_cost"] = timer.elapsed_time()
############## 3 ################ ############## 3 ################
...@@ -159,4 +160,6 @@ def seg_train(model, ...@@ -159,4 +160,6 @@ def seg_train(model,
############### 4 ############### ############### 4 ###############
cbks.on_train_end(logs) cbks.on_train_end(logs)
#################################
\ No newline at end of file
#################################
...@@ -67,7 +67,7 @@ def evaluate(model, ...@@ -67,7 +67,7 @@ def evaluate(model,
pred = pred[np.newaxis, :, :, np.newaxis] pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype('int64') pred = pred.astype('int64')
mask = label != ignore_index mask = label != ignore_index
# To-DO Test Execution Time
conf_mat.calculate(pred=pred, label=label, ignore=mask) conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou() _, iou = conf_mat.mean_iou()
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import time import time
...@@ -24,6 +23,7 @@ from visualdl import LogWriter ...@@ -24,6 +23,7 @@ from visualdl import LogWriter
from paddleseg.utils.progbar import Progbar from paddleseg.utils.progbar import Progbar
import paddleseg.utils.logger as logger import paddleseg.utils.logger as logger
class CallbackList(object): class CallbackList(object):
"""Container abstracting a list of callbacks. """Container abstracting a list of callbacks.
# Arguments # Arguments
...@@ -44,7 +44,7 @@ class CallbackList(object): ...@@ -44,7 +44,7 @@ class CallbackList(object):
def set_model(self, model): def set_model(self, model):
for callback in self.callbacks: for callback in self.callbacks:
callback.set_model(model) callback.set_model(model)
def set_optimizer(self, optimizer): def set_optimizer(self, optimizer):
for callback in self.callbacks: for callback in self.callbacks:
callback.set_optimizer(optimizer) callback.set_optimizer(optimizer)
...@@ -82,6 +82,7 @@ class CallbackList(object): ...@@ -82,6 +82,7 @@ class CallbackList(object):
def __iter__(self): def __iter__(self):
return iter(self.callbacks) return iter(self.callbacks)
class Callback(object): class Callback(object):
"""Abstract base class used to build new callbacks. """Abstract base class used to build new callbacks.
""" """
...@@ -94,7 +95,7 @@ class Callback(object): ...@@ -94,7 +95,7 @@ class Callback(object):
def set_model(self, model): def set_model(self, model):
self.model = model self.model = model
def set_optimizer(self, optimizer): def set_optimizer(self, optimizer):
self.optimizer = optimizer self.optimizer = optimizer
...@@ -110,18 +111,18 @@ class Callback(object): ...@@ -110,18 +111,18 @@ class Callback(object):
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
pass pass
class BaseLogger(Callback):
class BaseLogger(Callback):
def __init__(self, period=10): def __init__(self, period=10):
super(BaseLogger, self).__init__() super(BaseLogger, self).__init__()
self.period = period self.period = period
def _reset(self): def _reset(self):
self.totals = {} self.totals = {}
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
self.totals = {} self.totals = {}
def on_iter_end(self, iter, logs=None): def on_iter_end(self, iter, logs=None):
logs = logs or {} logs = logs or {}
#(iter - 1) // iters_per_epoch + 1 #(iter - 1) // iters_per_epoch + 1
...@@ -132,13 +133,13 @@ class BaseLogger(Callback): ...@@ -132,13 +133,13 @@ class BaseLogger(Callback):
self.totals[k] = v self.totals[k] = v
if iter % self.period == 0 and ParallelEnv().local_rank == 0: if iter % self.period == 0 and ParallelEnv().local_rank == 0:
for k in self.totals: for k in self.totals:
logs[k] = self.totals[k] / self.period logs[k] = self.totals[k] / self.period
self._reset() self._reset()
class TrainLogger(Callback):
class TrainLogger(Callback):
def __init__(self, log_freq=10): def __init__(self, log_freq=10):
self.log_freq = log_freq self.log_freq = log_freq
...@@ -154,7 +155,7 @@ class TrainLogger(Callback): ...@@ -154,7 +155,7 @@ class TrainLogger(Callback):
return result.format(*arr) return result.format(*arr)
def on_iter_end(self, iter, logs=None): def on_iter_end(self, iter, logs=None):
if iter % self.log_freq == 0 and ParallelEnv().local_rank == 0: if iter % self.log_freq == 0 and ParallelEnv().local_rank == 0:
total_iters = self.params["total_iters"] total_iters = self.params["total_iters"]
iters_per_epoch = self.params["iters_per_epoch"] iters_per_epoch = self.params["iters_per_epoch"]
...@@ -167,49 +168,50 @@ class TrainLogger(Callback): ...@@ -167,49 +168,50 @@ class TrainLogger(Callback):
reader_cost = logs["reader_cost"] reader_cost = logs["reader_cost"]
logger.info( logger.info(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}". "[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
format(current_epoch, iter, total_iters, .format(current_epoch, iter, total_iters, loss, lr, batch_cost,
loss, lr, batch_cost, reader_cost, eta)) reader_cost, eta))
class ProgbarLogger(Callback):
class ProgbarLogger(Callback):
def __init__(self): def __init__(self):
super(ProgbarLogger, self).__init__() super(ProgbarLogger, self).__init__()
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
self.verbose = self.params["verbose"] self.verbose = self.params["verbose"]
self.total_iters = self.params["total_iters"] self.total_iters = self.params["total_iters"]
self.target = self.params["total_iters"] self.target = self.params["total_iters"]
self.progbar = Progbar(target=self.target, verbose=self.verbose) self.progbar = Progbar(target=self.target, verbose=self.verbose)
self.seen = 0 self.seen = 0
self.log_values = [] self.log_values = []
def on_iter_begin(self, iter, logs=None): def on_iter_begin(self, iter, logs=None):
#self.seen = 0 #self.seen = 0
if self.seen < self.target: if self.seen < self.target:
self.log_values = [] self.log_values = []
def on_iter_end(self, iter, logs=None): def on_iter_end(self, iter, logs=None):
logs = logs or {} logs = logs or {}
self.seen += 1 self.seen += 1
for k in self.params['metrics']: for k in self.params['metrics']:
if k in logs: if k in logs:
self.log_values.append((k, logs[k])) self.log_values.append((k, logs[k]))
#if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0: #if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0:
#print(self.log_values) #print(self.log_values)
if self.seen < self.target: if self.seen < self.target:
self.progbar.update(self.seen, self.log_values) self.progbar.update(self.seen, self.log_values)
class ModelCheckpoint(Callback): class ModelCheckpoint(Callback):
def __init__(self,
save_dir,
monitor="miou",
save_best_only=False,
save_params_only=True,
mode="max",
period=1):
def __init__(self, save_dir, monitor="miou",
save_best_only=False, save_params_only=True,
mode="max", period=1):
super(ModelCheckpoint, self).__init__() super(ModelCheckpoint, self).__init__()
self.monitor = monitor self.monitor = monitor
self.save_dir = save_dir self.save_dir = save_dir
...@@ -241,7 +243,7 @@ class ModelCheckpoint(Callback): ...@@ -241,7 +243,7 @@ class ModelCheckpoint(Callback):
current_save_dir = os.path.join(self.save_dir, "iter_{}".format(iter)) current_save_dir = os.path.join(self.save_dir, "iter_{}".format(iter))
current_save_dir = os.path.abspath(current_save_dir) current_save_dir = os.path.abspath(current_save_dir)
#if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0: #if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0:
#self.iters_since_last_save = 0 #self.iters_since_last_save = 0
if iter % self.period == 0 and ParallelEnv().local_rank == 0: if iter % self.period == 0 and ParallelEnv().local_rank == 0:
if self.verbose > 0: if self.verbose > 0:
print("iter {iter_num}: saving model to {path}".format( print("iter {iter_num}: saving model to {path}".format(
...@@ -252,11 +254,9 @@ class ModelCheckpoint(Callback): ...@@ -252,11 +254,9 @@ class ModelCheckpoint(Callback):
if not self.save_params_only: if not self.save_params_only:
paddle.save(self.optimizer.state_dict(), filepath) paddle.save(self.optimizer.state_dict(), filepath)
class VisualDL(Callback): class VisualDL(Callback):
def __init__(self, log_dir="./log", freq=1): def __init__(self, log_dir="./log", freq=1):
super(VisualDL, self).__init__() super(VisualDL, self).__init__()
self.log_dir = log_dir self.log_dir = log_dir
...@@ -274,4 +274,4 @@ class VisualDL(Callback): ...@@ -274,4 +274,4 @@ class VisualDL(Callback):
self.writer.flush() self.writer.flush()
def on_train_end(self, logs=None): def on_train_end(self, logs=None):
self.writer.close() self.writer.close()
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.utils import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU
class PAM(nn.Layer):
"""Position attention module"""
def __init__(self, in_channels):
super(PAM, self).__init__()
mid_channels = in_channels // 8
self.query_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.key_conv = nn.Conv2d(in_channels, mid_channels, 1, 1)
self.value_conv = nn.Conv2d(in_channels, in_channels, 1, 1)
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, _, h, w = x.shape
# query: n, h * w, c1
query = self.query_conv(x)
query = paddle.reshape(query, (n, -1, h * w))
query = paddle.transpose(query, (0, 2, 1))
# key: n, c1, h * w
key = self.key_conv(x)
key = paddle.reshape(key, (n, -1, h * w))
# sim: n, h * w, h * w
sim = paddle.bmm(query, key)
sim = F.softmax(sim, axis=-1)
value = self.value_conv(x)
value = paddle.reshape(value, (n, -1, h * w))
sim = paddle.transpose(sim, (0, 2, 1))
# feat: from (n, c2, h * w) -> (n, c2, h, w)
feat = paddle.bmm(value, sim)
feat = paddle.reshape(feat, (n, -1, h, w))
out = self.gamma * feat + x
return out
class CAM(nn.Layer):
"""Channel attention module"""
def __init__(self):
super(CAM, self).__init__()
self.gamma = self.create_parameter(
shape=[1],
dtype='float32',
default_initializer=nn.initializer.Constant(0))
def forward(self, x):
n, c, h, w = x.shape
# query: n, c, h * w
query = paddle.reshape(x, (n, c, h * w))
# key: n, h * w, c
key = paddle.reshape(x, (n, c, h * w))
key = paddle.transpose(key, (0, 2, 1))
# sim: n, c, c
sim = paddle.bmm(query, key)
# The danet author claims that this can avoid gradient divergence
sim = paddle.max(sim, axis=-1, keepdim=True).expand_as(sim) - sim
sim = F.softmax(sim, axis=-1)
# feat: from (n, c, h * w) to (n, c, h, w)
value = paddle.reshape(x, (n, c, h * w))
feat = paddle.bmm(sim, value)
feat = paddle.reshape(feat, (n, c, h, w))
out = self.gamma * feat + x
return out
class DAHead(nn.Layer):
"""
The Dual attention head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
"""
def __init__(self, num_classes, in_channels=None):
super(DAHead, self).__init__()
in_channels = in_channels[-1]
inter_channels = in_channels // 4
self.channel_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.position_conv = ConvBNReLU(
in_channels, inter_channels, 3, padding=1)
self.pam = PAM(inter_channels)
self.cam = CAM()
self.conv1 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.conv2 = ConvBNReLU(inter_channels, inter_channels, 3, padding=1)
self.aux_head_pam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.aux_head_cam = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.cls_head = nn.Sequential(
nn.Dropout2d(0.1), nn.Conv2d(inter_channels, num_classes, 1))
self.init_weight()
def forward(self, x, label=None):
feats = x[-1]
channel_feats = self.channel_conv(feats)
channel_feats = self.cam(channel_feats)
channel_feats = self.conv1(channel_feats)
cam_head = self.aux_head_cam(channel_feats)
position_feats = self.position_conv(feats)
position_feats = self.pam(position_feats)
position_feats = self.conv2(position_feats)
pam_head = self.aux_head_pam(position_feats)
feats_sum = position_feats + channel_feats
cam_logit = self.aux_head_cam(channel_feats)
pam_logit = self.aux_head_cam(position_feats)
logit = self.cls_head(feats_sum)
return [logit, cam_logit, pam_logit]
def init_weight(self):
"""Initialize the parameters of model parts."""
for sublayer in self.sublayers():
if isinstance(sublayer, nn.Conv2d):
param_init.normal_init(sublayer.weight, scale=0.001)
elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
param_init.constant_init(sublayer.bias, value=0)
@manager.MODELS.add_component
class DANet(nn.Layer):
"""
The DANet implementation based on PaddlePaddle.
The original article refers to
Fu, jun, et al. "Dual Attention Network for Scene Segmentation"
(https://arxiv.org/pdf/1809.02983.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): values in the tuple indicate the indices of output of backbone.
Only the last indice is used.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None):
super(DANet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = DAHead(num_classes=num_classes, in_channels=in_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
"""
Initialize the parameters of model parts.
Args:
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if pretrained is not None:
if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained)
else:
raise Exception(
'Pretrained model is not found: {}'.format(pretrained))
...@@ -14,36 +14,41 @@ ...@@ -14,36 +14,41 @@
import os import os
import paddle.fluid as fluid import paddle
from paddle.fluid.dygraph import Sequential, Conv2D import paddle.nn as nn
import paddle.nn.functional as F
from paddleseg.cvlibs import manager
from paddleseg.models.common.layer_libs import ConvBNReLU
from paddleseg import utils from paddleseg import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.common.layer_libs import ConvBNReLU, AuxLayer
class SpatialGatherBlock(fluid.dygraph.Layer): class SpatialGatherBlock(nn.Layer):
"""Aggregation layer to compute the pixel-region representation"""
def forward(self, pixels, regions): def forward(self, pixels, regions):
n, c, h, w = pixels.shape n, c, h, w = pixels.shape
_, k, _, _ = regions.shape _, k, _, _ = regions.shape
# pixels: from (n, c, h, w) to (n, h*w, c) # pixels: from (n, c, h, w) to (n, h*w, c)
pixels = fluid.layers.reshape(pixels, (n, c, h * w)) pixels = paddle.reshape(pixels, (n, c, h * w))
pixels = fluid.layers.transpose(pixels, (0, 2, 1)) pixels = paddle.transpose(pixels, (0, 2, 1))
# regions: from (n, k, h, w) to (n, k, h*w) # regions: from (n, k, h, w) to (n, k, h*w)
regions = fluid.layers.reshape(regions, (n, k, h * w)) regions = paddle.reshape(regions, (n, k, h * w))
regions = fluid.layers.softmax(regions, axis=2) regions = F.softmax(regions, axis=2)
# feats: from (n, k, c) to (n, c, k, 1) # feats: from (n, k, c) to (n, c, k, 1)
feats = fluid.layers.matmul(regions, pixels) feats = paddle.bmm(regions, pixels)
feats = fluid.layers.transpose(feats, (0, 2, 1)) feats = paddle.transpose(feats, (0, 2, 1))
feats = fluid.layers.unsqueeze(feats, axes=[-1]) feats = paddle.unsqueeze(feats, axis=-1)
return feats return feats
class SpatialOCRModule(fluid.dygraph.Layer): class SpatialOCRModule(nn.Layer):
"""Aggregate the global object representation to update the representation for each pixel"""
def __init__(self, def __init__(self,
in_channels, in_channels,
key_channels, key_channels,
...@@ -53,30 +58,31 @@ class SpatialOCRModule(fluid.dygraph.Layer): ...@@ -53,30 +58,31 @@ class SpatialOCRModule(fluid.dygraph.Layer):
self.attention_block = ObjectAttentionBlock(in_channels, key_channels) self.attention_block = ObjectAttentionBlock(in_channels, key_channels)
self.dropout_rate = dropout_rate self.dropout_rate = dropout_rate
self.conv1x1 = Conv2D(2 * in_channels, out_channels, 1) self.conv1x1 = nn.Sequential(
nn.Conv2d(2 * in_channels, out_channels, 1), nn.Dropout2d(0.1))
def forward(self, pixels, regions): def forward(self, pixels, regions):
context = self.attention_block(pixels, regions) context = self.attention_block(pixels, regions)
feats = fluid.layers.concat([context, pixels], axis=1) feats = paddle.concat([context, pixels], axis=1)
feats = self.conv1x1(feats) feats = self.conv1x1(feats)
feats = fluid.layers.dropout(feats, self.dropout_rate)
return feats return feats
class ObjectAttentionBlock(fluid.dygraph.Layer): class ObjectAttentionBlock(nn.Layer):
"""A self-attention module."""
def __init__(self, in_channels, key_channels): def __init__(self, in_channels, key_channels):
super(ObjectAttentionBlock, self).__init__() super(ObjectAttentionBlock, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.key_channels = key_channels self.key_channels = key_channels
self.f_pixel = Sequential( self.f_pixel = nn.Sequential(
ConvBNReLU(in_channels, key_channels, 1), ConvBNReLU(in_channels, key_channels, 1),
ConvBNReLU(key_channels, key_channels, 1)) ConvBNReLU(key_channels, key_channels, 1))
self.f_object = Sequential( self.f_object = nn.Sequential(
ConvBNReLU(in_channels, key_channels, 1), ConvBNReLU(in_channels, key_channels, 1),
ConvBNReLU(key_channels, key_channels, 1)) ConvBNReLU(key_channels, key_channels, 1))
...@@ -89,127 +95,140 @@ class ObjectAttentionBlock(fluid.dygraph.Layer): ...@@ -89,127 +95,140 @@ class ObjectAttentionBlock(fluid.dygraph.Layer):
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels) # query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query = self.f_pixel(x) query = self.f_pixel(x)
query = fluid.layers.reshape(query, (n, self.key_channels, -1)) query = paddle.reshape(query, (n, self.key_channels, -1))
query = fluid.layers.transpose(query, (0, 2, 1)) query = paddle.transpose(query, (0, 2, 1))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2) # key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key = self.f_object(proxy) key = self.f_object(proxy)
key = fluid.layers.reshape(key, (n, self.key_channels, -1)) key = paddle.reshape(key, (n, self.key_channels, -1))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels) # value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value = self.f_down(proxy) value = self.f_down(proxy)
value = fluid.layers.reshape(value, (n, self.key_channels, -1)) value = paddle.reshape(value, (n, self.key_channels, -1))
value = fluid.layers.transpose(value, (0, 2, 1)) value = paddle.transpose(value, (0, 2, 1))
# sim_map (n, h1*w1, h2*w2) # sim_map (n, h1*w1, h2*w2)
sim_map = fluid.layers.matmul(query, key) sim_map = paddle.bmm(query, key)
sim_map = (self.key_channels**-.5) * sim_map sim_map = (self.key_channels**-.5) * sim_map
sim_map = fluid.layers.softmax(sim_map, axis=-1) sim_map = F.softmax(sim_map, axis=-1)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1) # context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context = fluid.layers.matmul(sim_map, value) context = paddle.bmm(sim_map, value)
context = fluid.layers.transpose(context, (0, 2, 1)) context = paddle.transpose(context, (0, 2, 1))
context = fluid.layers.reshape(context, (n, self.key_channels, h, w)) context = paddle.reshape(context, (n, self.key_channels, h, w))
context = self.f_up(context) context = self.f_up(context)
return context return context
@manager.MODELS.add_component class OCRHead(nn.Layer):
class OCRNet(fluid.dygraph.Layer): """
The Object contextual representation head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self, def __init__(self,
num_classes, num_classes,
backbone,
model_pretrained=None,
in_channels=None, in_channels=None,
ocr_mid_channels=512, ocr_mid_channels=512,
ocr_key_channels=256, ocr_key_channels=256):
ignore_index=255): super(OCRHead, self).__init__()
super(OCRNet, self).__init__()
self.ignore_index = ignore_index
self.num_classes = num_classes self.num_classes = num_classes
self.EPS = 1e-5
self.backbone = backbone
self.spatial_gather = SpatialGatherBlock() self.spatial_gather = SpatialGatherBlock()
self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels, self.spatial_ocr = SpatialOCRModule(ocr_mid_channels, ocr_key_channels,
ocr_mid_channels) ocr_mid_channels)
self.conv3x3_ocr = ConvBNReLU(
in_channels, ocr_mid_channels, 3, padding=1)
self.cls_head = Conv2D(ocr_mid_channels, self.num_classes, 1)
self.aux_head = Sequential( self.indices = [-2, -1] if len(in_channels) > 1 else [-1, -1]
ConvBNReLU(in_channels, in_channels, 3, padding=1),
Conv2D(in_channels, self.num_classes, 1))
self.init_weight(model_pretrained) self.conv3x3_ocr = ConvBNReLU(
in_channels[self.indices[1]], ocr_mid_channels, 3, padding=1)
self.cls_head = nn.Conv2d(ocr_mid_channels, self.num_classes, 1)
self.aux_head = AuxLayer(in_channels[self.indices[0]],
in_channels[self.indices[0]], self.num_classes)
self.init_weight()
def forward(self, x, label=None): def forward(self, x, label=None):
feats = self.backbone(x) feat_shallow, feat_deep = x[self.indices[0]], x[self.indices[1]]
soft_regions = self.aux_head(feats) soft_regions = self.aux_head(feat_shallow)
pixels = self.conv3x3_ocr(feats) pixels = self.conv3x3_ocr(feat_deep)
object_regions = self.spatial_gather(pixels, soft_regions) object_regions = self.spatial_gather(pixels, soft_regions)
ocr = self.spatial_ocr(pixels, object_regions) ocr = self.spatial_ocr(pixels, object_regions)
logit = self.cls_head(ocr) logit = self.cls_head(ocr)
logit = fluid.layers.resize_bilinear(logit, x.shape[2:]) return [logit, soft_regions]
if self.training: def init_weight(self):
soft_regions = fluid.layers.resize_bilinear(soft_regions, """Initialize the parameters of model parts."""
x.shape[2:]) for sublayer in self.sublayers():
cls_loss = self._get_loss(logit, label) if isinstance(sublayer, nn.Conv2d):
aux_loss = self._get_loss(soft_regions, label) param_init.normal_init(sublayer.weight, scale=0.001)
return cls_loss + 0.4 * aux_loss elif isinstance(sublayer, nn.SyncBatchNorm):
param_init.constant_init(sublayer.weight, value=1)
score_map = fluid.layers.softmax(logit, axis=1) param_init.constant_init(sublayer.bias, value=0)
score_map = fluid.layers.transpose(score_map, [0, 2, 3, 1])
pred = fluid.layers.argmax(score_map, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3]) @manager.MODELS.add_component
return pred, score_map class OCRNet(nn.Layer):
"""
def init_weight(self, pretrained_model=None): The OCRNet implementation based on PaddlePaddle.
The original article refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of pixel representation.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def __init__(self,
num_classes,
backbone,
pretrained=None,
backbone_indices=None,
ocr_mid_channels=512,
ocr_key_channels=256):
super(OCRNet, self).__init__()
self.backbone = backbone
self.backbone_indices = backbone_indices
in_channels = [self.backbone.channels[i] for i in backbone_indices]
self.head = OCRHead(
num_classes=num_classes,
in_channels=in_channels,
ocr_mid_channels=ocr_mid_channels,
ocr_key_channels=ocr_key_channels)
self.init_weight(pretrained)
def forward(self, x, label=None):
feats = self.backbone(x)
feats = [feats[i] for i in self.backbone_indices]
preds = self.head(feats, label)
preds = [F.resize_bilinear(pred, x.shape[2:]) for pred in preds]
return preds
def init_weight(self, pretrained=None):
""" """
Initialize the parameters of model parts. Initialize the parameters of model parts.
Args: Args:
pretrained_model ([str], optional): the path of pretrained model.. Defaults to None. pretrained ([str], optional): the path of pretrained model.. Defaults to None.
""" """
if pretrained_model is not None: if pretrained is not None:
if os.path.exists(pretrained_model): if os.path.exists(pretrained):
utils.load_pretrained_model(self, pretrained_model) utils.load_pretrained_model(self, pretrained)
else: else:
raise Exception('Pretrained model is not found: {}'.format( raise Exception(
pretrained_model)) 'Pretrained model is not found: {}'.format(pretrained))
\ No newline at end of file
def _get_loss(self, logit, label):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
fluid.layers.mean(mask) + self.EPS)
label.stop_gradient = True
mask.stop_gradient = True
return avg_loss
...@@ -41,7 +41,7 @@ class ConfusionMatrix(object): ...@@ -41,7 +41,7 @@ class ConfusionMatrix(object):
label = np.asarray(label)[mask] label = np.asarray(label)[mask]
pred = np.asarray(pred)[mask] pred = np.asarray(pred)[mask]
one = np.ones_like(pred) one = np.ones_like(pred)
# Accumuate ([row=label, col=pred], 1) into sparse matrix # Accumuate ([row=label, col=pred], 1) into sparse
spm = csr_matrix((one, (label, pred)), spm = csr_matrix((one, (label, pred)),
shape=(self.num_classes, self.num_classes)) shape=(self.num_classes, self.num_classes))
spm = spm.todense() spm = spm.todense()
......
...@@ -17,8 +17,9 @@ import time ...@@ -17,8 +17,9 @@ import time
import numpy as np import numpy as np
class Progbar(object): class Progbar(object):
"""Displays a progress bar. """Displays a progress bar.
refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
Arguments: Arguments:
target: Total number of steps expected, None if unknown. target: Total number of steps expected, None if unknown.
...@@ -31,39 +32,39 @@ class Progbar(object): ...@@ -31,39 +32,39 @@ class Progbar(object):
unit_name: Display name for step counts (usually "step" or "sample"). unit_name: Display name for step counts (usually "step" or "sample").
""" """
def __init__(self, def __init__(self,
target, target,
width=30, width=30,
verbose=1, verbose=1,
interval=0.05, interval=0.05,
stateful_metrics=None, stateful_metrics=None,
unit_name='step'): unit_name='step'):
self.target = target self.target = target
self.width = width self.width = width
self.verbose = verbose self.verbose = verbose
self.interval = interval self.interval = interval
self.unit_name = unit_name self.unit_name = unit_name
if stateful_metrics: if stateful_metrics:
self.stateful_metrics = set(stateful_metrics) self.stateful_metrics = set(stateful_metrics)
else: else:
self.stateful_metrics = set() self.stateful_metrics = set()
self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and self._dynamic_display = ((hasattr(sys.stdout, 'isatty')
sys.stdout.isatty()) or and sys.stdout.isatty())
'ipykernel' in sys.modules or or 'ipykernel' in sys.modules
'posix' in sys.modules or or 'posix' in sys.modules
'PYCHARM_HOSTED' in os.environ) or 'PYCHARM_HOSTED' in os.environ)
self._total_width = 0 self._total_width = 0
self._seen_so_far = 0 self._seen_so_far = 0
# We use a dict + list to avoid garbage collection # We use a dict + list to avoid garbage collection
# issues found in OrderedDict # issues found in OrderedDict
self._values = {} self._values = {}
self._values_order = [] self._values_order = []
self._start = time.time() self._start = time.time()
self._last_update = 0 self._last_update = 0
def update(self, current, values=None, finalize=None): def update(self, current, values=None, finalize=None):
"""Updates the progress bar. """Updates the progress bar.
Arguments: Arguments:
current: Index of current step. current: Index of current step.
values: List of tuples: `(name, value_for_last_step)`. If `name` is in values: List of tuples: `(name, value_for_last_step)`. If `name` is in
...@@ -72,129 +73,131 @@ class Progbar(object): ...@@ -72,129 +73,131 @@ class Progbar(object):
finalize: Whether this is the last update for the progress bar. If finalize: Whether this is the last update for the progress bar. If
`None`, defaults to `current >= self.target`. `None`, defaults to `current >= self.target`.
""" """
if finalize is None: if finalize is None:
if self.target is None: if self.target is None:
finalize = False finalize = False
else: else:
finalize = current >= self.target finalize = current >= self.target
values = values or [] values = values or []
for k, v in values: for k, v in values:
if k not in self._values_order: if k not in self._values_order:
self._values_order.append(k) self._values_order.append(k)
if k not in self.stateful_metrics: if k not in self.stateful_metrics:
# In the case that progress bar doesn't have a target value in the first # In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will # epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force # cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s. # the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base = max(current - self._seen_so_far, 1) value_base = max(current - self._seen_so_far, 1)
if k not in self._values: if k not in self._values:
self._values[k] = [v * value_base, value_base] self._values[k] = [v * value_base, value_base]
else: else:
self._values[k][0] += v * value_base self._values[k][0] += v * value_base
self._values[k][1] += value_base self._values[k][1] += value_base
else: else:
# Stateful metrics output a numeric value. This representation # Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the # means "take an average from a single value" but keeps the
# numeric formatting. # numeric formatting.
self._values[k] = [v, 1] self._values[k] = [v, 1]
self._seen_so_far = current self._seen_so_far = current
now = time.time() now = time.time()
info = ' - %.0fs' % (now - self._start) info = ' - %.0fs' % (now - self._start)
if self.verbose == 1: if self.verbose == 1:
if now - self._last_update < self.interval and not finalize: if now - self._last_update < self.interval and not finalize:
return return
prev_total_width = self._total_width prev_total_width = self._total_width
if self._dynamic_display: if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width) sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r') sys.stdout.write('\r')
else: else:
sys.stdout.write('\n') sys.stdout.write('\n')
if self.target is not None: if self.target is not None:
numdigits = int(np.log10(self.target)) + 1 numdigits = int(np.log10(self.target)) + 1
bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target) bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
prog = float(current) / self.target prog = float(current) / self.target
prog_width = int(self.width * prog) prog_width = int(self.width * prog)
if prog_width > 0: if prog_width > 0:
bar += ('=' * (prog_width - 1)) bar += ('=' * (prog_width - 1))
if current < self.target: if current < self.target:
bar += '>' bar += '>'
else: else:
bar += '=' bar += '='
bar += ('.' * (self.width - prog_width)) bar += ('.' * (self.width - prog_width))
bar += ']' bar += ']'
else: else:
bar = '%7d/Unknown' % current bar = '%7d/Unknown' % current
self._total_width = len(bar) self._total_width = len(bar)
sys.stdout.write(bar) sys.stdout.write(bar)
if current: if current:
time_per_unit = (now - self._start) / current time_per_unit = (now - self._start) / current
else: else:
time_per_unit = 0 time_per_unit = 0
if self.target is None or finalize: if self.target is None or finalize:
if time_per_unit >= 1 or time_per_unit == 0: if time_per_unit >= 1 or time_per_unit == 0:
info += ' %.0fs/%s' % (time_per_unit, self.unit_name) info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
elif time_per_unit >= 1e-3: elif time_per_unit >= 1e-3:
info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name) info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
else: else:
info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name) info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
else: else:
eta = time_per_unit * (self.target - current) eta = time_per_unit * (self.target - current)
if eta > 3600: if eta > 3600:
eta_format = '%d:%02d:%02d' % (eta // 3600, eta_format = '%d:%02d:%02d' % (eta // 3600,
(eta % 3600) // 60, eta % 60) (eta % 3600) // 60, eta % 60)
elif eta > 60: elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60) eta_format = '%d:%02d' % (eta // 60, eta % 60)
else: else:
eta_format = '%ds' % eta eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format info = ' - ETA: %s' % eta_format
for k in self._values_order: for k in self._values_order:
info += ' - %s:' % k info += ' - %s:' % k
if isinstance(self._values[k], list): if isinstance(self._values[k], list):
avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) avg = np.mean(
if abs(avg) > 1e-3: self._values[k][0] / max(1, self._values[k][1]))
info += ' %.4f' % avg if abs(avg) > 1e-3:
else: info += ' %.4f' % avg
info += ' %.4e' % avg else:
else: info += ' %.4e' % avg
info += ' %s' % self._values[k] else:
info += ' %s' % self._values[k]
self._total_width += len(info)
if prev_total_width > self._total_width: self._total_width += len(info)
info += (' ' * (prev_total_width - self._total_width)) if prev_total_width > self._total_width:
info += (' ' * (prev_total_width - self._total_width))
if finalize:
info += '\n' if finalize:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush() sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if finalize: elif self.verbose == 2:
numdigits = int(np.log10(self.target)) + 1 if finalize:
count = ('%' + str(numdigits) + 'd/%d') % (current, self.target) numdigits = int(np.log10(self.target)) + 1
info = count + info count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
for k in self._values_order: info = count + info
info += ' - %s:' % k for k in self._values_order:
avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) info += ' - %s:' % k
if avg > 1e-3: avg = np.mean(
info += ' %.4f' % avg self._values[k][0] / max(1, self._values[k][1]))
else: if avg > 1e-3:
info += ' %.4e' % avg info += ' %.4f' % avg
info += '\n' else:
info += ' %.4e' % avg
sys.stdout.write(info) info += '\n'
sys.stdout.flush()
sys.stdout.write(info)
self._last_update = now sys.stdout.flush()
def add(self, n, values=None): self._last_update = now
self.update(self._seen_so_far + n, values)
\ No newline at end of file def add(self, n, values=None):
self.update(self._seen_so_far + n, values)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册