未验证 提交 16350b38 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #326 from wuyefeilin/dygraph

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .train import train
from .val import evaluate
from .infer import infer
__all__ = ['train', 'evaluate', 'infer']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
import cv2
import tqdm
import utils
import utils.logging as logging
def mkdir(path):
sub_dir = os.path.dirname(path)
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
added_saved_dir = os.path.join(save_dir, 'added')
pred_saved_dir = os.path.join(save_dir, 'prediction')
logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = to_variable(im)
pred, _ = model(im)
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/':
im_file = im_file[1:]
# save added image
added_image = utils.visualize(im_path, pred, weight=0.6)
added_image_path = os.path.join(added_saved_dir, im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path, added_image)
# save prediction
pred_im = utils.visualize(im_path, pred, weight=0.0)
pred_saved_path = os.path.join(pred_saved_dir, im_file)
mkdir(pred_saved_path)
cv2.imwrite(pred_saved_path, pred_im)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
import utils.logging as logging
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from .val import evaluate
def train(model,
train_dataset,
places=None,
eval_dataset=None,
optimizer=None,
save_dir='output',
num_epochs=100,
batch_size=2,
pretrained_model=None,
resume_model=None,
save_interval_epochs=1,
log_steps=10,
num_classes=None,
num_workers=8,
use_vdl=False):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
start_epoch = 0
if resume_model is not None:
start_epoch = resume(model, optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
ddp_model = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
places=places,
num_workers=num_workers,
return_list=True,
)
if use_vdl:
from visualdl import LogWriter
log_writer = LogWriter(save_dir)
timer = Timer()
avg_loss = 0.0
steps_per_epoch = len(batch_sampler)
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = -1
train_reader_cost = 0.0
train_batch_cost = 0.0
for epoch in range(start_epoch, num_epochs):
timer.start()
for step, data in enumerate(loader):
train_reader_cost += timer.elapsed_time()
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
else:
loss = model(images, labels)
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
avg_loss += loss.numpy()[0]
lr = optimizer.current_step_lr()
num_steps += 1
train_batch_cost += timer.elapsed_time()
if num_steps % log_steps == 0 and ParallelEnv().local_rank == 0:
avg_loss /= log_steps
avg_train_reader_cost = train_reader_cost / log_steps
avg_train_batch_cost = train_batch_cost / log_steps
train_reader_cost = 0.0
train_batch_cost = 0.0
remain_steps = total_steps - num_steps
eta = calculate_eta(remain_steps, avg_train_batch_cost)
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.format(epoch + 1, num_epochs, step + 1, steps_per_epoch,
avg_loss * nranks, lr, avg_train_batch_cost,
avg_train_reader_cost, eta))
if use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, num_steps)
log_writer.add_scalar('Train/lr', lr, num_steps)
log_writer.add_scalar('Train/batch_cost',
avg_train_batch_cost, num_steps)
log_writer.add_scalar('Train/reader_cost',
avg_train_reader_cost, num_steps)
avg_loss = 0.0
timer.restart()
if ((epoch + 1) % save_interval_epochs == 0
or epoch + 1 == num_epochs) and ParallelEnv().local_rank == 0:
current_save_dir = os.path.join(save_dir,
"epoch_{}".format(epoch + 1))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None:
mean_iou, avg_acc = evaluate(
model,
eval_dataset,
model_dir=current_save_dir,
num_classes=num_classes,
ignore_index=ignore_index,
epoch_id=epoch + 1)
if mean_iou > best_mean_iou:
best_mean_iou = mean_iou
best_model_epoch = epoch + 1
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model'))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mIoU', mean_iou, epoch + 1)
log_writer.add_scalar('Evaluate/aAcc', avg_acc, epoch + 1)
model.train()
if use_vdl:
log_writer.close()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
import utils.logging as logging
from utils import ConfusionMatrix
from utils import Timer, calculate_eta
def evaluate(model,
eval_dataset=None,
model_dir=None,
num_classes=None,
ignore_index=255,
epoch_id=None):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
total_steps = len(eval_dataset)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logging.info(
"Start to evaluating(total_samples={}, total_steps={})...".format(
len(eval_dataset), total_steps))
timer = Timer()
timer.start()
for step, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_steps):
im = to_variable(im)
pred, _ = model(im)
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype('int64')
mask = label != ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou()
time_step = timer.elapsed_time()
remain_step = total_steps - step - 1
logging.debug(
"[EVAL] Epoch={}, Step={}/{}, iou={:4f}, sec/step={:.4f} | ETA {}".
format(epoch_id, step + 1, total_steps, iou, time_step,
calculate_eta(remain_step, time_step)))
timer.restart()
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
logging.info("[EVAL] #Images={} mAcc={:.4f} mIoU={:.4f}".format(
len(eval_dataset), macc, miou))
logging.info("[EVAL] Category IoU: " + str(category_iou))
logging.info("[EVAL] Category Acc: " + str(category_acc))
logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
return miou, macc
......@@ -28,6 +28,7 @@ from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info
from core import infer
def parse_args():
......@@ -81,54 +82,6 @@ def parse_args():
return parser.parse_args()
def mkdir(path):
sub_dir = os.path.dirname(path)
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
added_saved_dir = os.path.join(save_dir, 'added')
pred_saved_dir = os.path.join(save_dir, 'prediction')
logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = to_variable(im)
pred, _ = model(im, mode='test')
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/':
im_file = im_file[1:]
# save added image
added_image = utils.visualize(im_path, pred, weight=0.6)
added_image_path = os.path.join(added_saved_dir, im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path, added_image)
# save prediction
pred_im = utils.visualize(im_path, pred, weight=0.0)
pred_saved_path = os.path.join(pred_saved_dir, im_file)
mkdir(pred_saved_path)
cv2.imwrite(pred_saved_path, pred_im)
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
......
......@@ -19,10 +19,8 @@ import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
try:
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
except:
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.initializer import Normal
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
__all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......@@ -140,7 +138,8 @@ class HRNet(fluid.dygraph.Layer):
filter_size=1,
stride=1,
padding=0,
param_attr=ParamAttr(name='conv-1_weights'))
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name='conv-1_weights'))
def forward(self, x, label=None, mode='train'):
input_shape = x.shape[2:]
......@@ -167,7 +166,7 @@ class HRNet(fluid.dygraph.Layer):
logit = self.conv_last_1(x)
logit = fluid.layers.resize_bilinear(logit, input_shape)
if mode == 'train':
if self.training:
if label is None:
raise Exception('Label is need during training')
return self._get_loss(logit, label)
......@@ -218,14 +217,19 @@ class ConvBNLayer(fluid.dygraph.Layer):
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(
initializer=Normal(scale=0.001), name=name + "_weights"),
bias_attr=False)
bn_name = name + '_bn'
self._batch_norm = BatchNorm(
num_filters,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
param_attr=ParamAttr(
name=bn_name + '_scale',
initializer=fluid.initializer.Constant(1.0)),
bias_attr=ParamAttr(
bn_name + '_offset',
initializer=fluid.initializer.Constant(0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
......@@ -646,7 +650,7 @@ class FuseLayers(fluid.dygraph.Layer):
y = self.residual_func_list[residual_func_idx](input[j])
residual_func_idx += 1
y = fluid.layers.resize_nearest(input=y, scale=2**(j - i))
y = fluid.layers.resize_bilinear(input=y, scale=2**(j - i))
residual = fluid.layers.elementwise_add(
x=residual, y=y, act=None)
elif j < i:
......
......@@ -14,26 +14,23 @@
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D
try:
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
except:
from paddle.fluid.dygraph import BatchNorm
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
class UNet(fluid.dygraph.Layer):
def __init__(self, num_classes, ignore_index=255):
super().__init__()
super(UNet, self).__init__()
self.encode = UnetEncoder()
self.decode = UnetDecode()
self.get_logit = GetLogit(64, num_classes)
self.ignore_index = ignore_index
self.EPS = 1e-5
def forward(self, x, label=None, mode='train'):
def forward(self, x, label=None):
encode_data, short_cuts = self.encode(x)
decode_data = self.decode(encode_data, short_cuts)
logit = self.get_logit(decode_data)
if mode == 'train':
if self.training:
return self._get_loss(logit, label)
else:
score_map = fluid.layers.softmax(logit, axis=1)
......@@ -52,7 +49,7 @@ class UNet(fluid.dygraph.Layer):
label,
ignore_index=self.ignore_index,
return_softmax=True,
axis=1)
axis=-1)
loss = loss * mask
avg_loss = fluid.layers.mean(loss) / (
......@@ -65,7 +62,7 @@ class UNet(fluid.dygraph.Layer):
class UnetEncoder(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
super(UnetEncoder, self).__init__()
self.double_conv = DoubleConv(3, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
......@@ -88,7 +85,7 @@ class UnetEncoder(fluid.dygraph.Layer):
class UnetDecode(fluid.dygraph.Layer):
def __init__(self):
super().__init__()
super(UnetDecode, self).__init__()
self.up1 = Up(512, 256)
self.up2 = Up(256, 128)
self.up3 = Up(128, 64)
......@@ -104,7 +101,7 @@ class UnetDecode(fluid.dygraph.Layer):
class DoubleConv(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(DoubleConv, self).__init__()
self.conv0 = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
......@@ -132,7 +129,7 @@ class DoubleConv(fluid.dygraph.Layer):
class Down(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(Down, self).__init__()
self.max_pool = Pool2D(
pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
self.double_conv = DoubleConv(num_channels, num_filters)
......@@ -145,7 +142,7 @@ class Down(fluid.dygraph.Layer):
class Up(fluid.dygraph.Layer):
def __init__(self, num_channels, num_filters):
super().__init__()
super(Up, self).__init__()
self.double_conv = DoubleConv(2 * num_channels, num_filters)
def forward(self, x, short_cut):
......@@ -158,7 +155,7 @@ class Up(fluid.dygraph.Layer):
class GetLogit(fluid.dygraph.Layer):
def __init__(self, num_channels, num_classes):
super().__init__()
super(GetLogit, self).__init__()
self.conv = Conv2D(
num_channels=num_channels,
num_filters=num_classes,
......
......@@ -28,7 +28,7 @@ from utils import get_environ_info
from utils import load_pretrained_model
from utils import resume
from utils import Timer, calculate_eta
from val import evaluate
from core import train
def parse_args():
......@@ -128,132 +128,6 @@ def parse_args():
return parser.parse_args()
def train(model,
train_dataset,
places=None,
eval_dataset=None,
optimizer=None,
save_dir='output',
num_epochs=100,
batch_size=2,
pretrained_model=None,
resume_model=None,
save_interval_epochs=1,
log_steps=10,
num_classes=None,
num_workers=8,
use_vdl=False):
ignore_index = model.ignore_index
nranks = ParallelEnv().nranks
start_epoch = 0
if resume_model is not None:
start_epoch = resume(model, optimizer, resume_model)
elif pretrained_model is not None:
load_pretrained_model(model, pretrained_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
model_parallel = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
places=places,
num_workers=num_workers,
return_list=True,
)
if use_vdl:
from visualdl import LogWriter
log_writer = LogWriter(save_dir)
timer = Timer()
timer.start()
avg_loss = 0.0
steps_per_epoch = len(batch_sampler)
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = -1
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
loss = model_parallel(images, labels, mode='train')
loss = model_parallel.scale_loss(loss)
loss.backward()
model_parallel.apply_collective_grads()
else:
loss = model(images, labels, mode='train')
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
avg_loss += loss.numpy()[0]
lr = optimizer.current_step_lr()
num_steps += 1
if num_steps % log_steps == 0 and ParallelEnv().local_rank == 0:
avg_loss /= log_steps
time_step = timer.elapsed_time() / log_steps
remain_steps = total_steps - num_steps
logging.info(
"[TRAIN] Epoch={}/{}, Step={}/{}, loss={:.4f}, lr={:.6f}, sec/step={:.4f} | ETA {}"
.format(epoch + 1, num_epochs, step + 1, steps_per_epoch,
avg_loss, lr, time_step,
calculate_eta(remain_steps, time_step)))
if use_vdl:
log_writer.add_scalar('Train/loss', avg_loss, num_steps)
log_writer.add_scalar('Train/lr', lr, num_steps)
avg_loss = 0.0
timer.restart()
if ((epoch + 1) % save_interval_epochs == 0
or epoch + 1 == num_epochs) and ParallelEnv().local_rank == 0:
current_save_dir = os.path.join(save_dir,
"epoch_{}".format(epoch + 1))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None:
mean_iou, mean_acc = evaluate(
model,
eval_dataset,
model_dir=current_save_dir,
num_classes=num_classes,
ignore_index=ignore_index,
epoch_id=epoch + 1)
if mean_iou > best_mean_iou:
best_mean_iou = mean_iou
best_model_epoch = epoch + 1
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model'))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
epoch + 1)
log_writer.add_scalar('Evaluate/mean_acc', mean_acc,
epoch + 1)
model.train()
if use_vdl:
log_writer.close()
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
......
......@@ -32,6 +32,7 @@ import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
from utils import Timer, calculate_eta
from core import evaluate
def parse_args():
......@@ -73,65 +74,6 @@ def parse_args():
return parser.parse_args()
def evaluate(model,
eval_dataset=None,
model_dir=None,
num_classes=None,
ignore_index=255,
epoch_id=None):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
total_steps = len(eval_dataset)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logging.info(
"Start to evaluating(total_samples={}, total_steps={})...".format(
len(eval_dataset), total_steps))
timer = Timer()
timer.start()
for step, (im, im_info, label) in enumerate(eval_dataset):
im = to_variable(im)
pred, _ = model(im, mode='eval')
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype('int64')
mask = label != ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou()
time_step = timer.elapsed_time()
remain_step = total_steps - step - 1
logging.info(
"[EVAL] Epoch={}, Step={}/{}, iou={:4f}, sec/step={:.4f} | ETA {}".
format(epoch_id, step + 1, total_steps, iou, time_step,
calculate_eta(remain_step, time_step)))
timer.restart()
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
logging.info("[EVAL] #image={} acc={:.4f} IoU={:.4f}".format(
len(eval_dataset), macc, miou))
logging.info("[EVAL] Category IoU: " + str(category_iou))
logging.info("[EVAL] Category Acc: " + str(category_acc))
logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
return miou, macc
def main(args):
env_info = get_environ_info()
places = fluid.CUDAPlace(ParallelEnv().dev_id) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册