Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
4b5665d0
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4b5665d0
编写于
9月 22, 2020
作者:
M
michaelowenliu
浏览文件
操作
浏览文件
下载
差异文件
add ocrnet
上级
8e0e4e39
23d69271
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
544 addition
and
302 deletion
+544
-302
dygraph/configs/_base_/cityscapes.yml
dygraph/configs/_base_/cityscapes.yml
+2
-2
dygraph/paddleseg/core/seg_train.py
dygraph/paddleseg/core/seg_train.py
+8
-5
dygraph/paddleseg/core/val.py
dygraph/paddleseg/core/val.py
+1
-1
dygraph/paddleseg/cvlibs/callbacks.py
dygraph/paddleseg/cvlibs/callbacks.py
+29
-29
dygraph/paddleseg/models/danet.py
dygraph/paddleseg/models/danet.py
+217
-0
dygraph/paddleseg/models/ocrnet.py
dygraph/paddleseg/models/ocrnet.py
+123
-104
dygraph/paddleseg/utils/metrics.py
dygraph/paddleseg/utils/metrics.py
+1
-1
dygraph/paddleseg/utils/progbar.py
dygraph/paddleseg/utils/progbar.py
+163
-160
未找到文件。
dygraph/configs/_base_/cityscapes.yml
浏览文件 @
4b5665d0
...
...
@@ -4,7 +4,7 @@ learning_rate: 0.01
train_dataset
:
type
:
Cityscapes
dataset_root
:
data
/cityscapes
dataset_root
:
/mnt/liuyi22/.cache/paddle/dataset
/cityscapes
transforms
:
-
type
:
ResizeStepScaling
min_scale_factor
:
0.5
...
...
@@ -18,7 +18,7 @@ train_dataset:
val_dataset
:
type
:
Cityscapes
dataset_root
:
data
/cityscapes
dataset_root
:
/mnt/liuyi22/.cache/paddle/dataset
/cityscapes
transforms
:
-
type
:
Normalize
mode
:
val
...
...
dygraph/paddleseg/core/seg_train.py
浏览文件 @
4b5665d0
...
...
@@ -87,7 +87,8 @@ def seg_train(model,
out_labels
=
[
"loss"
,
"reader_cost"
,
"batch_cost"
]
base_logger
=
callbacks
.
BaseLogger
(
period
=
log_iters
)
train_logger
=
callbacks
.
TrainLogger
(
log_freq
=
log_iters
)
model_ckpt
=
callbacks
.
ModelCheckpoint
(
save_dir
,
save_params_only
=
False
,
period
=
save_interval_iters
)
model_ckpt
=
callbacks
.
ModelCheckpoint
(
save_dir
,
save_params_only
=
False
,
period
=
save_interval_iters
)
vdl
=
callbacks
.
VisualDL
(
log_dir
=
os
.
path
.
join
(
save_dir
,
"log"
))
cbks_list
=
[
base_logger
,
train_logger
,
model_ckpt
,
vdl
]
...
...
@@ -120,7 +121,7 @@ def seg_train(model,
iter
+=
1
if
iter
>
iters
:
break
logs
[
"reader_cost"
]
=
timer
.
elapsed_time
()
############## 2 ################
cbks
.
on_iter_begin
(
iter
,
logs
)
...
...
@@ -136,7 +137,7 @@ def seg_train(model,
loss
=
ddp_model
.
scale_loss
(
loss
)
loss
.
backward
()
ddp_model
.
apply_collective_grads
()
else
:
logits
=
model
(
images
)
loss
=
loss_computation
(
logits
,
labels
,
losses
)
...
...
@@ -148,7 +149,7 @@ def seg_train(model,
model
.
clear_gradients
()
logs
[
'loss'
]
=
loss
.
numpy
()[
0
]
logs
[
"batch_cost"
]
=
timer
.
elapsed_time
()
############## 3 ################
...
...
@@ -159,4 +160,6 @@ def seg_train(model,
############### 4 ###############
cbks
.
on_train_end
(
logs
)
#################################
\ No newline at end of file
#################################
dygraph/paddleseg/core/val.py
浏览文件 @
4b5665d0
...
...
@@ -67,7 +67,7 @@ def evaluate(model,
pred
=
pred
[
np
.
newaxis
,
:,
:,
np
.
newaxis
]
pred
=
pred
.
astype
(
'int64'
)
mask
=
label
!=
ignore_index
# To-DO Test Execution Time
conf_mat
.
calculate
(
pred
=
pred
,
label
=
label
,
ignore
=
mask
)
_
,
iou
=
conf_mat
.
mean_iou
()
...
...
dygraph/paddleseg/cvlibs/callbacks.py
浏览文件 @
4b5665d0
...
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
time
...
...
@@ -24,6 +23,7 @@ from visualdl import LogWriter
from
paddleseg.utils.progbar
import
Progbar
import
paddleseg.utils.logger
as
logger
class
CallbackList
(
object
):
"""Container abstracting a list of callbacks.
# Arguments
...
...
@@ -44,7 +44,7 @@ class CallbackList(object):
def
set_model
(
self
,
model
):
for
callback
in
self
.
callbacks
:
callback
.
set_model
(
model
)
def
set_optimizer
(
self
,
optimizer
):
for
callback
in
self
.
callbacks
:
callback
.
set_optimizer
(
optimizer
)
...
...
@@ -82,6 +82,7 @@ class CallbackList(object):
def
__iter__
(
self
):
return
iter
(
self
.
callbacks
)
class
Callback
(
object
):
"""Abstract base class used to build new callbacks.
"""
...
...
@@ -94,7 +95,7 @@ class Callback(object):
def
set_model
(
self
,
model
):
self
.
model
=
model
def
set_optimizer
(
self
,
optimizer
):
self
.
optimizer
=
optimizer
...
...
@@ -110,18 +111,18 @@ class Callback(object):
def
on_train_end
(
self
,
logs
=
None
):
pass
class
BaseLogger
(
Callback
):
class
BaseLogger
(
Callback
):
def
__init__
(
self
,
period
=
10
):
super
(
BaseLogger
,
self
).
__init__
()
self
.
period
=
period
def
_reset
(
self
):
self
.
totals
=
{}
def
on_train_begin
(
self
,
logs
=
None
):
self
.
totals
=
{}
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
#(iter - 1) // iters_per_epoch + 1
...
...
@@ -132,13 +133,13 @@ class BaseLogger(Callback):
self
.
totals
[
k
]
=
v
if
iter
%
self
.
period
==
0
and
ParallelEnv
().
local_rank
==
0
:
for
k
in
self
.
totals
:
logs
[
k
]
=
self
.
totals
[
k
]
/
self
.
period
self
.
_reset
()
class
TrainLogger
(
Callback
):
class
TrainLogger
(
Callback
):
def
__init__
(
self
,
log_freq
=
10
):
self
.
log_freq
=
log_freq
...
...
@@ -154,7 +155,7 @@ class TrainLogger(Callback):
return
result
.
format
(
*
arr
)
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
if
iter
%
self
.
log_freq
==
0
and
ParallelEnv
().
local_rank
==
0
:
total_iters
=
self
.
params
[
"total_iters"
]
iters_per_epoch
=
self
.
params
[
"iters_per_epoch"
]
...
...
@@ -167,49 +168,50 @@ class TrainLogger(Callback):
reader_cost
=
logs
[
"reader_cost"
]
logger
.
info
(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.
format
(
current_epoch
,
iter
,
total_iters
,
loss
,
lr
,
batch_cost
,
reader_cost
,
eta
))
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.
format
(
current_epoch
,
iter
,
total_iters
,
loss
,
lr
,
batch_cost
,
reader_cost
,
eta
))
class
ProgbarLogger
(
Callback
):
class
ProgbarLogger
(
Callback
):
def
__init__
(
self
):
super
(
ProgbarLogger
,
self
).
__init__
()
def
on_train_begin
(
self
,
logs
=
None
):
self
.
verbose
=
self
.
params
[
"verbose"
]
self
.
total_iters
=
self
.
params
[
"total_iters"
]
self
.
target
=
self
.
params
[
"total_iters"
]
self
.
target
=
self
.
params
[
"total_iters"
]
self
.
progbar
=
Progbar
(
target
=
self
.
target
,
verbose
=
self
.
verbose
)
self
.
seen
=
0
self
.
log_values
=
[]
def
on_iter_begin
(
self
,
iter
,
logs
=
None
):
#self.seen = 0
if
self
.
seen
<
self
.
target
:
self
.
log_values
=
[]
def
on_iter_end
(
self
,
iter
,
logs
=
None
):
logs
=
logs
or
{}
self
.
seen
+=
1
for
k
in
self
.
params
[
'metrics'
]:
if
k
in
logs
:
self
.
log_values
.
append
((
k
,
logs
[
k
]))
#if self.verbose and self.seen < self.target and ParallelEnv.local_rank == 0:
#print(self.log_values)
#print(self.log_values)
if
self
.
seen
<
self
.
target
:
self
.
progbar
.
update
(
self
.
seen
,
self
.
log_values
)
class
ModelCheckpoint
(
Callback
):
def
__init__
(
self
,
save_dir
,
monitor
=
"miou"
,
save_best_only
=
False
,
save_params_only
=
True
,
mode
=
"max"
,
period
=
1
):
def
__init__
(
self
,
save_dir
,
monitor
=
"miou"
,
save_best_only
=
False
,
save_params_only
=
True
,
mode
=
"max"
,
period
=
1
):
super
(
ModelCheckpoint
,
self
).
__init__
()
self
.
monitor
=
monitor
self
.
save_dir
=
save_dir
...
...
@@ -241,7 +243,7 @@ class ModelCheckpoint(Callback):
current_save_dir
=
os
.
path
.
join
(
self
.
save_dir
,
"iter_{}"
.
format
(
iter
))
current_save_dir
=
os
.
path
.
abspath
(
current_save_dir
)
#if self.iters_since_last_save % self.period and ParallelEnv().local_rank == 0:
#self.iters_since_last_save = 0
#self.iters_since_last_save = 0
if
iter
%
self
.
period
==
0
and
ParallelEnv
().
local_rank
==
0
:
if
self
.
verbose
>
0
:
print
(
"iter {iter_num}: saving model to {path}"
.
format
(
...
...
@@ -252,11 +254,9 @@ class ModelCheckpoint(Callback):
if
not
self
.
save_params_only
:
paddle
.
save
(
self
.
optimizer
.
state_dict
(),
filepath
)
class
VisualDL
(
Callback
):
def
__init__
(
self
,
log_dir
=
"./log"
,
freq
=
1
):
super
(
VisualDL
,
self
).
__init__
()
self
.
log_dir
=
log_dir
...
...
@@ -274,4 +274,4 @@ class VisualDL(Callback):
self
.
writer
.
flush
()
def
on_train_end
(
self
,
logs
=
None
):
self
.
writer
.
close
()
\ No newline at end of file
self
.
writer
.
close
()
dygraph/paddleseg/models/danet.py
0 → 100644
浏览文件 @
4b5665d0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddleseg.utils
import
utils
from
paddleseg.cvlibs
import
manager
,
param_init
from
paddleseg.models.common.layer_libs
import
ConvBNReLU
class
PAM
(
nn
.
Layer
):
"""Position attention module"""
def
__init__
(
self
,
in_channels
):
super
(
PAM
,
self
).
__init__
()
mid_channels
=
in_channels
//
8
self
.
query_conv
=
nn
.
Conv2d
(
in_channels
,
mid_channels
,
1
,
1
)
self
.
key_conv
=
nn
.
Conv2d
(
in_channels
,
mid_channels
,
1
,
1
)
self
.
value_conv
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
1
,
1
)
self
.
gamma
=
self
.
create_parameter
(
shape
=
[
1
],
dtype
=
'float32'
,
default_initializer
=
nn
.
initializer
.
Constant
(
0
))
def
forward
(
self
,
x
):
n
,
_
,
h
,
w
=
x
.
shape
# query: n, h * w, c1
query
=
self
.
query_conv
(
x
)
query
=
paddle
.
reshape
(
query
,
(
n
,
-
1
,
h
*
w
))
query
=
paddle
.
transpose
(
query
,
(
0
,
2
,
1
))
# key: n, c1, h * w
key
=
self
.
key_conv
(
x
)
key
=
paddle
.
reshape
(
key
,
(
n
,
-
1
,
h
*
w
))
# sim: n, h * w, h * w
sim
=
paddle
.
bmm
(
query
,
key
)
sim
=
F
.
softmax
(
sim
,
axis
=-
1
)
value
=
self
.
value_conv
(
x
)
value
=
paddle
.
reshape
(
value
,
(
n
,
-
1
,
h
*
w
))
sim
=
paddle
.
transpose
(
sim
,
(
0
,
2
,
1
))
# feat: from (n, c2, h * w) -> (n, c2, h, w)
feat
=
paddle
.
bmm
(
value
,
sim
)
feat
=
paddle
.
reshape
(
feat
,
(
n
,
-
1
,
h
,
w
))
out
=
self
.
gamma
*
feat
+
x
return
out
class
CAM
(
nn
.
Layer
):
"""Channel attention module"""
def
__init__
(
self
):
super
(
CAM
,
self
).
__init__
()
self
.
gamma
=
self
.
create_parameter
(
shape
=
[
1
],
dtype
=
'float32'
,
default_initializer
=
nn
.
initializer
.
Constant
(
0
))
def
forward
(
self
,
x
):
n
,
c
,
h
,
w
=
x
.
shape
# query: n, c, h * w
query
=
paddle
.
reshape
(
x
,
(
n
,
c
,
h
*
w
))
# key: n, h * w, c
key
=
paddle
.
reshape
(
x
,
(
n
,
c
,
h
*
w
))
key
=
paddle
.
transpose
(
key
,
(
0
,
2
,
1
))
# sim: n, c, c
sim
=
paddle
.
bmm
(
query
,
key
)
# The danet author claims that this can avoid gradient divergence
sim
=
paddle
.
max
(
sim
,
axis
=-
1
,
keepdim
=
True
).
expand_as
(
sim
)
-
sim
sim
=
F
.
softmax
(
sim
,
axis
=-
1
)
# feat: from (n, c, h * w) to (n, c, h, w)
value
=
paddle
.
reshape
(
x
,
(
n
,
c
,
h
*
w
))
feat
=
paddle
.
bmm
(
sim
,
value
)
feat
=
paddle
.
reshape
(
feat
,
(
n
,
c
,
h
,
w
))
out
=
self
.
gamma
*
feat
+
x
return
out
class
DAHead
(
nn
.
Layer
):
"""
The Dual attention head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
"""
def
__init__
(
self
,
num_classes
,
in_channels
=
None
):
super
(
DAHead
,
self
).
__init__
()
in_channels
=
in_channels
[
-
1
]
inter_channels
=
in_channels
//
4
self
.
channel_conv
=
ConvBNReLU
(
in_channels
,
inter_channels
,
3
,
padding
=
1
)
self
.
position_conv
=
ConvBNReLU
(
in_channels
,
inter_channels
,
3
,
padding
=
1
)
self
.
pam
=
PAM
(
inter_channels
)
self
.
cam
=
CAM
()
self
.
conv1
=
ConvBNReLU
(
inter_channels
,
inter_channels
,
3
,
padding
=
1
)
self
.
conv2
=
ConvBNReLU
(
inter_channels
,
inter_channels
,
3
,
padding
=
1
)
self
.
aux_head_pam
=
nn
.
Sequential
(
nn
.
Dropout2d
(
0.1
),
nn
.
Conv2d
(
inter_channels
,
num_classes
,
1
))
self
.
aux_head_cam
=
nn
.
Sequential
(
nn
.
Dropout2d
(
0.1
),
nn
.
Conv2d
(
inter_channels
,
num_classes
,
1
))
self
.
cls_head
=
nn
.
Sequential
(
nn
.
Dropout2d
(
0.1
),
nn
.
Conv2d
(
inter_channels
,
num_classes
,
1
))
self
.
init_weight
()
def
forward
(
self
,
x
,
label
=
None
):
feats
=
x
[
-
1
]
channel_feats
=
self
.
channel_conv
(
feats
)
channel_feats
=
self
.
cam
(
channel_feats
)
channel_feats
=
self
.
conv1
(
channel_feats
)
cam_head
=
self
.
aux_head_cam
(
channel_feats
)
position_feats
=
self
.
position_conv
(
feats
)
position_feats
=
self
.
pam
(
position_feats
)
position_feats
=
self
.
conv2
(
position_feats
)
pam_head
=
self
.
aux_head_pam
(
position_feats
)
feats_sum
=
position_feats
+
channel_feats
cam_logit
=
self
.
aux_head_cam
(
channel_feats
)
pam_logit
=
self
.
aux_head_cam
(
position_feats
)
logit
=
self
.
cls_head
(
feats_sum
)
return
[
logit
,
cam_logit
,
pam_logit
]
def
init_weight
(
self
):
"""Initialize the parameters of model parts."""
for
sublayer
in
self
.
sublayers
():
if
isinstance
(
sublayer
,
nn
.
Conv2d
):
param_init
.
normal_init
(
sublayer
.
weight
,
scale
=
0.001
)
elif
isinstance
(
sublayer
,
nn
.
SyncBatchNorm
):
param_init
.
constant_init
(
sublayer
.
weight
,
value
=
1
)
param_init
.
constant_init
(
sublayer
.
bias
,
value
=
0
)
@
manager
.
MODELS
.
add_component
class
DANet
(
nn
.
Layer
):
"""
The DANet implementation based on PaddlePaddle.
The original article refers to
Fu, jun, et al. "Dual Attention Network for Scene Segmentation"
(https://arxiv.org/pdf/1809.02983.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): values in the tuple indicate the indices of output of backbone.
Only the last indice is used.
"""
def
__init__
(
self
,
num_classes
,
backbone
,
pretrained
=
None
,
backbone_indices
=
None
):
super
(
DANet
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
backbone_indices
=
backbone_indices
in_channels
=
[
self
.
backbone
.
channels
[
i
]
for
i
in
backbone_indices
]
self
.
head
=
DAHead
(
num_classes
=
num_classes
,
in_channels
=
in_channels
)
self
.
init_weight
(
pretrained
)
def
forward
(
self
,
x
,
label
=
None
):
feats
=
self
.
backbone
(
x
)
feats
=
[
feats
[
i
]
for
i
in
self
.
backbone_indices
]
preds
=
self
.
head
(
feats
,
label
)
preds
=
[
F
.
resize_bilinear
(
pred
,
x
.
shape
[
2
:])
for
pred
in
preds
]
return
preds
def
init_weight
(
self
,
pretrained
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if
pretrained
is
not
None
:
if
os
.
path
.
exists
(
pretrained
):
utils
.
load_pretrained_model
(
self
,
pretrained
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained
))
dygraph/paddleseg/models/ocrnet.py
浏览文件 @
4b5665d0
...
...
@@ -14,36 +14,41 @@
import
os
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Sequential
,
Conv2D
import
paddle
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddleseg.cvlibs
import
manager
from
paddleseg.models.common.layer_libs
import
ConvBNReLU
from
paddleseg
import
utils
from
paddleseg.cvlibs
import
manager
,
param_init
from
paddleseg.models.common.layer_libs
import
ConvBNReLU
,
AuxLayer
class
SpatialGatherBlock
(
fluid
.
dygraph
.
Layer
):
class
SpatialGatherBlock
(
nn
.
Layer
):
"""Aggregation layer to compute the pixel-region representation"""
def
forward
(
self
,
pixels
,
regions
):
n
,
c
,
h
,
w
=
pixels
.
shape
_
,
k
,
_
,
_
=
regions
.
shape
# pixels: from (n, c, h, w) to (n, h*w, c)
pixels
=
fluid
.
layers
.
reshape
(
pixels
,
(
n
,
c
,
h
*
w
))
pixels
=
fluid
.
layers
.
transpose
(
pixels
,
(
0
,
2
,
1
))
pixels
=
paddle
.
reshape
(
pixels
,
(
n
,
c
,
h
*
w
))
pixels
=
paddle
.
transpose
(
pixels
,
(
0
,
2
,
1
))
# regions: from (n, k, h, w) to (n, k, h*w)
regions
=
fluid
.
layers
.
reshape
(
regions
,
(
n
,
k
,
h
*
w
))
regions
=
fluid
.
layers
.
softmax
(
regions
,
axis
=
2
)
regions
=
paddle
.
reshape
(
regions
,
(
n
,
k
,
h
*
w
))
regions
=
F
.
softmax
(
regions
,
axis
=
2
)
# feats: from (n, k, c) to (n, c, k, 1)
feats
=
fluid
.
layers
.
matmul
(
regions
,
pixels
)
feats
=
fluid
.
layers
.
transpose
(
feats
,
(
0
,
2
,
1
))
feats
=
fluid
.
layers
.
unsqueeze
(
feats
,
axes
=
[
-
1
]
)
feats
=
paddle
.
bmm
(
regions
,
pixels
)
feats
=
paddle
.
transpose
(
feats
,
(
0
,
2
,
1
))
feats
=
paddle
.
unsqueeze
(
feats
,
axis
=-
1
)
return
feats
class
SpatialOCRModule
(
fluid
.
dygraph
.
Layer
):
class
SpatialOCRModule
(
nn
.
Layer
):
"""Aggregate the global object representation to update the representation for each pixel"""
def
__init__
(
self
,
in_channels
,
key_channels
,
...
...
@@ -53,30 +58,31 @@ class SpatialOCRModule(fluid.dygraph.Layer):
self
.
attention_block
=
ObjectAttentionBlock
(
in_channels
,
key_channels
)
self
.
dropout_rate
=
dropout_rate
self
.
conv1x1
=
Conv2D
(
2
*
in_channels
,
out_channels
,
1
)
self
.
conv1x1
=
nn
.
Sequential
(
nn
.
Conv2d
(
2
*
in_channels
,
out_channels
,
1
),
nn
.
Dropout2d
(
0.1
))
def
forward
(
self
,
pixels
,
regions
):
context
=
self
.
attention_block
(
pixels
,
regions
)
feats
=
fluid
.
layers
.
concat
([
context
,
pixels
],
axis
=
1
)
feats
=
paddle
.
concat
([
context
,
pixels
],
axis
=
1
)
feats
=
self
.
conv1x1
(
feats
)
feats
=
fluid
.
layers
.
dropout
(
feats
,
self
.
dropout_rate
)
return
feats
class
ObjectAttentionBlock
(
fluid
.
dygraph
.
Layer
):
class
ObjectAttentionBlock
(
nn
.
Layer
):
"""A self-attention module."""
def
__init__
(
self
,
in_channels
,
key_channels
):
super
(
ObjectAttentionBlock
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
key_channels
=
key_channels
self
.
f_pixel
=
Sequential
(
self
.
f_pixel
=
nn
.
Sequential
(
ConvBNReLU
(
in_channels
,
key_channels
,
1
),
ConvBNReLU
(
key_channels
,
key_channels
,
1
))
self
.
f_object
=
Sequential
(
self
.
f_object
=
nn
.
Sequential
(
ConvBNReLU
(
in_channels
,
key_channels
,
1
),
ConvBNReLU
(
key_channels
,
key_channels
,
1
))
...
...
@@ -89,127 +95,140 @@ class ObjectAttentionBlock(fluid.dygraph.Layer):
# query : from (n, c1, h1, w1) to (n, h1*w1, key_channels)
query
=
self
.
f_pixel
(
x
)
query
=
fluid
.
layers
.
reshape
(
query
,
(
n
,
self
.
key_channels
,
-
1
))
query
=
fluid
.
layers
.
transpose
(
query
,
(
0
,
2
,
1
))
query
=
paddle
.
reshape
(
query
,
(
n
,
self
.
key_channels
,
-
1
))
query
=
paddle
.
transpose
(
query
,
(
0
,
2
,
1
))
# key : from (n, c2, h2, w2) to (n, key_channels, h2*w2)
key
=
self
.
f_object
(
proxy
)
key
=
fluid
.
layers
.
reshape
(
key
,
(
n
,
self
.
key_channels
,
-
1
))
key
=
paddle
.
reshape
(
key
,
(
n
,
self
.
key_channels
,
-
1
))
# value : from (n, c2, h2, w2) to (n, h2*w2, key_channels)
value
=
self
.
f_down
(
proxy
)
value
=
fluid
.
layers
.
reshape
(
value
,
(
n
,
self
.
key_channels
,
-
1
))
value
=
fluid
.
layers
.
transpose
(
value
,
(
0
,
2
,
1
))
value
=
paddle
.
reshape
(
value
,
(
n
,
self
.
key_channels
,
-
1
))
value
=
paddle
.
transpose
(
value
,
(
0
,
2
,
1
))
# sim_map (n, h1*w1, h2*w2)
sim_map
=
fluid
.
layers
.
matmul
(
query
,
key
)
sim_map
=
paddle
.
bmm
(
query
,
key
)
sim_map
=
(
self
.
key_channels
**-
.
5
)
*
sim_map
sim_map
=
fluid
.
layers
.
softmax
(
sim_map
,
axis
=-
1
)
sim_map
=
F
.
softmax
(
sim_map
,
axis
=-
1
)
# context from (n, h1*w1, key_channels) to (n , out_channels, h1, w1)
context
=
fluid
.
layers
.
matmul
(
sim_map
,
value
)
context
=
fluid
.
layers
.
transpose
(
context
,
(
0
,
2
,
1
))
context
=
fluid
.
layers
.
reshape
(
context
,
(
n
,
self
.
key_channels
,
h
,
w
))
context
=
paddle
.
bmm
(
sim_map
,
value
)
context
=
paddle
.
transpose
(
context
,
(
0
,
2
,
1
))
context
=
paddle
.
reshape
(
context
,
(
n
,
self
.
key_channels
,
h
,
w
))
context
=
self
.
f_up
(
context
)
return
context
@
manager
.
MODELS
.
add_component
class
OCRNet
(
fluid
.
dygraph
.
Layer
):
class
OCRHead
(
nn
.
Layer
):
"""
The Object contextual representation head.
Args:
num_classes(int): the unique number of target classes.
in_channels(tuple): the number of input channels.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def
__init__
(
self
,
num_classes
,
backbone
,
model_pretrained
=
None
,
in_channels
=
None
,
ocr_mid_channels
=
512
,
ocr_key_channels
=
256
,
ignore_index
=
255
):
super
(
OCRNet
,
self
).
__init__
()
ocr_key_channels
=
256
):
super
(
OCRHead
,
self
).
__init__
()
self
.
ignore_index
=
ignore_index
self
.
num_classes
=
num_classes
self
.
EPS
=
1e-5
self
.
backbone
=
backbone
self
.
spatial_gather
=
SpatialGatherBlock
()
self
.
spatial_ocr
=
SpatialOCRModule
(
ocr_mid_channels
,
ocr_key_channels
,
ocr_mid_channels
)
self
.
conv3x3_ocr
=
ConvBNReLU
(
in_channels
,
ocr_mid_channels
,
3
,
padding
=
1
)
self
.
cls_head
=
Conv2D
(
ocr_mid_channels
,
self
.
num_classes
,
1
)
self
.
aux_head
=
Sequential
(
ConvBNReLU
(
in_channels
,
in_channels
,
3
,
padding
=
1
),
Conv2D
(
in_channels
,
self
.
num_classes
,
1
))
self
.
indices
=
[
-
2
,
-
1
]
if
len
(
in_channels
)
>
1
else
[
-
1
,
-
1
]
self
.
init_weight
(
model_pretrained
)
self
.
conv3x3_ocr
=
ConvBNReLU
(
in_channels
[
self
.
indices
[
1
]],
ocr_mid_channels
,
3
,
padding
=
1
)
self
.
cls_head
=
nn
.
Conv2d
(
ocr_mid_channels
,
self
.
num_classes
,
1
)
self
.
aux_head
=
AuxLayer
(
in_channels
[
self
.
indices
[
0
]],
in_channels
[
self
.
indices
[
0
]],
self
.
num_classes
)
self
.
init_weight
()
def
forward
(
self
,
x
,
label
=
None
):
feat
s
=
self
.
backbone
(
x
)
feat
_shallow
,
feat_deep
=
x
[
self
.
indices
[
0
]],
x
[
self
.
indices
[
1
]]
soft_regions
=
self
.
aux_head
(
feat
s
)
pixels
=
self
.
conv3x3_ocr
(
feat
s
)
soft_regions
=
self
.
aux_head
(
feat
_shallow
)
pixels
=
self
.
conv3x3_ocr
(
feat
_deep
)
object_regions
=
self
.
spatial_gather
(
pixels
,
soft_regions
)
ocr
=
self
.
spatial_ocr
(
pixels
,
object_regions
)
logit
=
self
.
cls_head
(
ocr
)
logit
=
fluid
.
layers
.
resize_bilinear
(
logit
,
x
.
shape
[
2
:])
if
self
.
training
:
soft_regions
=
fluid
.
layers
.
resize_bilinear
(
soft_regions
,
x
.
shape
[
2
:])
cls_loss
=
self
.
_get_loss
(
logit
,
label
)
aux_loss
=
self
.
_get_loss
(
soft_regions
,
label
)
return
cls_loss
+
0.4
*
aux_loss
score_map
=
fluid
.
layers
.
softmax
(
logit
,
axis
=
1
)
score_map
=
fluid
.
layers
.
transpose
(
score_map
,
[
0
,
2
,
3
,
1
])
pred
=
fluid
.
layers
.
argmax
(
score_map
,
axis
=
3
)
pred
=
fluid
.
layers
.
unsqueeze
(
pred
,
axes
=
[
3
])
return
pred
,
score_map
def
init_weight
(
self
,
pretrained_model
=
None
):
return
[
logit
,
soft_regions
]
def
init_weight
(
self
):
"""Initialize the parameters of model parts."""
for
sublayer
in
self
.
sublayers
():
if
isinstance
(
sublayer
,
nn
.
Conv2d
):
param_init
.
normal_init
(
sublayer
.
weight
,
scale
=
0.001
)
elif
isinstance
(
sublayer
,
nn
.
SyncBatchNorm
):
param_init
.
constant_init
(
sublayer
.
weight
,
value
=
1
)
param_init
.
constant_init
(
sublayer
.
bias
,
value
=
0
)
@
manager
.
MODELS
.
add_component
class
OCRNet
(
nn
.
Layer
):
"""
The OCRNet implementation based on PaddlePaddle.
The original article refers to
Yuan, Yuhui, et al. "Object-Contextual Representations for Semantic Segmentation"
(https://arxiv.org/pdf/1909.11065.pdf)
Args:
num_classes(int): the unique number of target classes.
backbone(Paddle.nn.Layer): backbone network.
pretrained(str): the path or url of pretrained model. Default to None.
backbone_indices(tuple): two values in the tuple indicate the indices of output of backbone.
the first index will be taken as a deep-supervision feature in auxiliary layer;
the second one will be taken as input of pixel representation.
ocr_mid_channels(int): the number of middle channels in OCRHead.
ocr_key_channels(int): the number of key channels in ObjectAttentionBlock.
"""
def
__init__
(
self
,
num_classes
,
backbone
,
pretrained
=
None
,
backbone_indices
=
None
,
ocr_mid_channels
=
512
,
ocr_key_channels
=
256
):
super
(
OCRNet
,
self
).
__init__
()
self
.
backbone
=
backbone
self
.
backbone_indices
=
backbone_indices
in_channels
=
[
self
.
backbone
.
channels
[
i
]
for
i
in
backbone_indices
]
self
.
head
=
OCRHead
(
num_classes
=
num_classes
,
in_channels
=
in_channels
,
ocr_mid_channels
=
ocr_mid_channels
,
ocr_key_channels
=
ocr_key_channels
)
self
.
init_weight
(
pretrained
)
def
forward
(
self
,
x
,
label
=
None
):
feats
=
self
.
backbone
(
x
)
feats
=
[
feats
[
i
]
for
i
in
self
.
backbone_indices
]
preds
=
self
.
head
(
feats
,
label
)
preds
=
[
F
.
resize_bilinear
(
pred
,
x
.
shape
[
2
:])
for
pred
in
preds
]
return
preds
def
init_weight
(
self
,
pretrained
=
None
):
"""
Initialize the parameters of model parts.
Args:
pretrained
_model
([str], optional): the path of pretrained model.. Defaults to None.
pretrained ([str], optional): the path of pretrained model.. Defaults to None.
"""
if
pretrained
_model
is
not
None
:
if
os
.
path
.
exists
(
pretrained
_model
):
utils
.
load_pretrained_model
(
self
,
pretrained
_model
)
if
pretrained
is
not
None
:
if
os
.
path
.
exists
(
pretrained
):
utils
.
load_pretrained_model
(
self
,
pretrained
)
else
:
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained_model
))
def
_get_loss
(
self
,
logit
,
label
):
"""
compute forward loss of the model
Args:
logit (tensor): the logit of model output
label (tensor): ground truth
Returns:
avg_loss (tensor): forward loss
"""
logit
=
fluid
.
layers
.
transpose
(
logit
,
[
0
,
2
,
3
,
1
])
label
=
fluid
.
layers
.
transpose
(
label
,
[
0
,
2
,
3
,
1
])
mask
=
label
!=
self
.
ignore_index
mask
=
fluid
.
layers
.
cast
(
mask
,
'float32'
)
loss
,
probs
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logit
,
label
,
ignore_index
=
self
.
ignore_index
,
return_softmax
=
True
,
axis
=-
1
)
loss
=
loss
*
mask
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
/
(
fluid
.
layers
.
mean
(
mask
)
+
self
.
EPS
)
label
.
stop_gradient
=
True
mask
.
stop_gradient
=
True
return
avg_loss
raise
Exception
(
'Pretrained model is not found: {}'
.
format
(
pretrained
))
\ No newline at end of file
dygraph/paddleseg/utils/metrics.py
浏览文件 @
4b5665d0
...
...
@@ -41,7 +41,7 @@ class ConfusionMatrix(object):
label
=
np
.
asarray
(
label
)[
mask
]
pred
=
np
.
asarray
(
pred
)[
mask
]
one
=
np
.
ones_like
(
pred
)
# Accumuate ([row=label, col=pred], 1) into sparse
matrix
# Accumuate ([row=label, col=pred], 1) into sparse
spm
=
csr_matrix
((
one
,
(
label
,
pred
)),
shape
=
(
self
.
num_classes
,
self
.
num_classes
))
spm
=
spm
.
todense
()
...
...
dygraph/paddleseg/utils/progbar.py
浏览文件 @
4b5665d0
...
...
@@ -17,8 +17,9 @@ import time
import
numpy
as
np
class
Progbar
(
object
):
"""Displays a progress bar.
"""Displays a progress bar.
refers to https://github.com/keras-team/keras/blob/keras-2/keras/utils/generic_utils.py
Arguments:
target: Total number of steps expected, None if unknown.
...
...
@@ -31,39 +32,39 @@ class Progbar(object):
unit_name: Display name for step counts (usually "step" or "sample").
"""
def
__init__
(
self
,
target
,
width
=
30
,
verbose
=
1
,
interval
=
0.05
,
stateful_metrics
=
None
,
unit_name
=
'step'
):
self
.
target
=
target
self
.
width
=
width
self
.
verbose
=
verbose
self
.
interval
=
interval
self
.
unit_name
=
unit_name
if
stateful_metrics
:
self
.
stateful_metrics
=
set
(
stateful_metrics
)
else
:
self
.
stateful_metrics
=
set
()
self
.
_dynamic_display
=
((
hasattr
(
sys
.
stdout
,
'isatty'
)
and
sys
.
stdout
.
isatty
())
or
'ipykernel'
in
sys
.
modules
or
'posix'
in
sys
.
modules
or
'PYCHARM_HOSTED'
in
os
.
environ
)
self
.
_total_width
=
0
self
.
_seen_so_far
=
0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
self
.
_values
=
{}
self
.
_values_order
=
[]
self
.
_start
=
time
.
time
()
self
.
_last_update
=
0
def
update
(
self
,
current
,
values
=
None
,
finalize
=
None
):
"""Updates the progress bar.
def
__init__
(
self
,
target
,
width
=
30
,
verbose
=
1
,
interval
=
0.05
,
stateful_metrics
=
None
,
unit_name
=
'step'
):
self
.
target
=
target
self
.
width
=
width
self
.
verbose
=
verbose
self
.
interval
=
interval
self
.
unit_name
=
unit_name
if
stateful_metrics
:
self
.
stateful_metrics
=
set
(
stateful_metrics
)
else
:
self
.
stateful_metrics
=
set
()
self
.
_dynamic_display
=
((
hasattr
(
sys
.
stdout
,
'isatty'
)
and
sys
.
stdout
.
isatty
())
or
'ipykernel'
in
sys
.
modules
or
'posix'
in
sys
.
modules
or
'PYCHARM_HOSTED'
in
os
.
environ
)
self
.
_total_width
=
0
self
.
_seen_so_far
=
0
# We use a dict + list to avoid garbage collection
# issues found in OrderedDict
self
.
_values
=
{}
self
.
_values_order
=
[]
self
.
_start
=
time
.
time
()
self
.
_last_update
=
0
def
update
(
self
,
current
,
values
=
None
,
finalize
=
None
):
"""Updates the progress bar.
Arguments:
current: Index of current step.
values: List of tuples: `(name, value_for_last_step)`. If `name` is in
...
...
@@ -72,129 +73,131 @@ class Progbar(object):
finalize: Whether this is the last update for the progress bar. If
`None`, defaults to `current >= self.target`.
"""
if
finalize
is
None
:
if
self
.
target
is
None
:
finalize
=
False
else
:
finalize
=
current
>=
self
.
target
values
=
values
or
[]
for
k
,
v
in
values
:
if
k
not
in
self
.
_values_order
:
self
.
_values_order
.
append
(
k
)
if
k
not
in
self
.
stateful_metrics
:
# In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base
=
max
(
current
-
self
.
_seen_so_far
,
1
)
if
k
not
in
self
.
_values
:
self
.
_values
[
k
]
=
[
v
*
value_base
,
value_base
]
else
:
self
.
_values
[
k
][
0
]
+=
v
*
value_base
self
.
_values
[
k
][
1
]
+=
value_base
else
:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self
.
_values
[
k
]
=
[
v
,
1
]
self
.
_seen_so_far
=
current
now
=
time
.
time
()
info
=
' - %.0fs'
%
(
now
-
self
.
_start
)
if
self
.
verbose
==
1
:
if
now
-
self
.
_last_update
<
self
.
interval
and
not
finalize
:
return
prev_total_width
=
self
.
_total_width
if
self
.
_dynamic_display
:
sys
.
stdout
.
write
(
'
\b
'
*
prev_total_width
)
sys
.
stdout
.
write
(
'
\r
'
)
else
:
sys
.
stdout
.
write
(
'
\n
'
)
if
self
.
target
is
not
None
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
bar
=
(
'%'
+
str
(
numdigits
)
+
'd/%d ['
)
%
(
current
,
self
.
target
)
prog
=
float
(
current
)
/
self
.
target
prog_width
=
int
(
self
.
width
*
prog
)
if
prog_width
>
0
:
bar
+=
(
'='
*
(
prog_width
-
1
))
if
current
<
self
.
target
:
bar
+=
'>'
else
:
bar
+=
'='
bar
+=
(
'.'
*
(
self
.
width
-
prog_width
))
bar
+=
']'
else
:
bar
=
'%7d/Unknown'
%
current
self
.
_total_width
=
len
(
bar
)
sys
.
stdout
.
write
(
bar
)
if
current
:
time_per_unit
=
(
now
-
self
.
_start
)
/
current
else
:
time_per_unit
=
0
if
self
.
target
is
None
or
finalize
:
if
time_per_unit
>=
1
or
time_per_unit
==
0
:
info
+=
' %.0fs/%s'
%
(
time_per_unit
,
self
.
unit_name
)
elif
time_per_unit
>=
1e-3
:
info
+=
' %.0fms/%s'
%
(
time_per_unit
*
1e3
,
self
.
unit_name
)
else
:
info
+=
' %.0fus/%s'
%
(
time_per_unit
*
1e6
,
self
.
unit_name
)
else
:
eta
=
time_per_unit
*
(
self
.
target
-
current
)
if
eta
>
3600
:
eta_format
=
'%d:%02d:%02d'
%
(
eta
//
3600
,
(
eta
%
3600
)
//
60
,
eta
%
60
)
elif
eta
>
60
:
eta_format
=
'%d:%02d'
%
(
eta
//
60
,
eta
%
60
)
else
:
eta_format
=
'%ds'
%
eta
info
=
' - ETA: %s'
%
eta_format
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
if
isinstance
(
self
.
_values
[
k
],
list
):
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
abs
(
avg
)
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
else
:
info
+=
' %s'
%
self
.
_values
[
k
]
self
.
_total_width
+=
len
(
info
)
if
prev_total_width
>
self
.
_total_width
:
info
+=
(
' '
*
(
prev_total_width
-
self
.
_total_width
))
if
finalize
:
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
elif
self
.
verbose
==
2
:
if
finalize
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
count
=
(
'%'
+
str
(
numdigits
)
+
'd/%d'
)
%
(
current
,
self
.
target
)
info
=
count
+
info
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
avg
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
self
.
_last_update
=
now
def
add
(
self
,
n
,
values
=
None
):
self
.
update
(
self
.
_seen_so_far
+
n
,
values
)
\ No newline at end of file
if
finalize
is
None
:
if
self
.
target
is
None
:
finalize
=
False
else
:
finalize
=
current
>=
self
.
target
values
=
values
or
[]
for
k
,
v
in
values
:
if
k
not
in
self
.
_values_order
:
self
.
_values_order
.
append
(
k
)
if
k
not
in
self
.
stateful_metrics
:
# In the case that progress bar doesn't have a target value in the first
# epoch, both on_batch_end and on_epoch_end will be called, which will
# cause 'current' and 'self._seen_so_far' to have the same value. Force
# the minimal value to 1 here, otherwise stateful_metric will be 0s.
value_base
=
max
(
current
-
self
.
_seen_so_far
,
1
)
if
k
not
in
self
.
_values
:
self
.
_values
[
k
]
=
[
v
*
value_base
,
value_base
]
else
:
self
.
_values
[
k
][
0
]
+=
v
*
value_base
self
.
_values
[
k
][
1
]
+=
value_base
else
:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self
.
_values
[
k
]
=
[
v
,
1
]
self
.
_seen_so_far
=
current
now
=
time
.
time
()
info
=
' - %.0fs'
%
(
now
-
self
.
_start
)
if
self
.
verbose
==
1
:
if
now
-
self
.
_last_update
<
self
.
interval
and
not
finalize
:
return
prev_total_width
=
self
.
_total_width
if
self
.
_dynamic_display
:
sys
.
stdout
.
write
(
'
\b
'
*
prev_total_width
)
sys
.
stdout
.
write
(
'
\r
'
)
else
:
sys
.
stdout
.
write
(
'
\n
'
)
if
self
.
target
is
not
None
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
bar
=
(
'%'
+
str
(
numdigits
)
+
'd/%d ['
)
%
(
current
,
self
.
target
)
prog
=
float
(
current
)
/
self
.
target
prog_width
=
int
(
self
.
width
*
prog
)
if
prog_width
>
0
:
bar
+=
(
'='
*
(
prog_width
-
1
))
if
current
<
self
.
target
:
bar
+=
'>'
else
:
bar
+=
'='
bar
+=
(
'.'
*
(
self
.
width
-
prog_width
))
bar
+=
']'
else
:
bar
=
'%7d/Unknown'
%
current
self
.
_total_width
=
len
(
bar
)
sys
.
stdout
.
write
(
bar
)
if
current
:
time_per_unit
=
(
now
-
self
.
_start
)
/
current
else
:
time_per_unit
=
0
if
self
.
target
is
None
or
finalize
:
if
time_per_unit
>=
1
or
time_per_unit
==
0
:
info
+=
' %.0fs/%s'
%
(
time_per_unit
,
self
.
unit_name
)
elif
time_per_unit
>=
1e-3
:
info
+=
' %.0fms/%s'
%
(
time_per_unit
*
1e3
,
self
.
unit_name
)
else
:
info
+=
' %.0fus/%s'
%
(
time_per_unit
*
1e6
,
self
.
unit_name
)
else
:
eta
=
time_per_unit
*
(
self
.
target
-
current
)
if
eta
>
3600
:
eta_format
=
'%d:%02d:%02d'
%
(
eta
//
3600
,
(
eta
%
3600
)
//
60
,
eta
%
60
)
elif
eta
>
60
:
eta_format
=
'%d:%02d'
%
(
eta
//
60
,
eta
%
60
)
else
:
eta_format
=
'%ds'
%
eta
info
=
' - ETA: %s'
%
eta_format
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
if
isinstance
(
self
.
_values
[
k
],
list
):
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
abs
(
avg
)
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
else
:
info
+=
' %s'
%
self
.
_values
[
k
]
self
.
_total_width
+=
len
(
info
)
if
prev_total_width
>
self
.
_total_width
:
info
+=
(
' '
*
(
prev_total_width
-
self
.
_total_width
))
if
finalize
:
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
elif
self
.
verbose
==
2
:
if
finalize
:
numdigits
=
int
(
np
.
log10
(
self
.
target
))
+
1
count
=
(
'%'
+
str
(
numdigits
)
+
'd/%d'
)
%
(
current
,
self
.
target
)
info
=
count
+
info
for
k
in
self
.
_values_order
:
info
+=
' - %s:'
%
k
avg
=
np
.
mean
(
self
.
_values
[
k
][
0
]
/
max
(
1
,
self
.
_values
[
k
][
1
]))
if
avg
>
1e-3
:
info
+=
' %.4f'
%
avg
else
:
info
+=
' %.4e'
%
avg
info
+=
'
\n
'
sys
.
stdout
.
write
(
info
)
sys
.
stdout
.
flush
()
self
.
_last_update
=
now
def
add
(
self
,
n
,
values
=
None
):
self
.
update
(
self
.
_seen_so_far
+
n
,
values
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录