提交 3d021fb2 编写于 作者: M minqiyang

Port object_detection to Python3

上级 06820fcb
...@@ -22,6 +22,7 @@ import xml.etree.ElementTree ...@@ -22,6 +22,7 @@ import xml.etree.ElementTree
import os import os
import time import time
import copy import copy
import six
class Settings(object): class Settings(object):
...@@ -151,7 +152,7 @@ def preprocess(img, bbox_labels, mode, settings): ...@@ -151,7 +152,7 @@ def preprocess(img, bbox_labels, mode, settings):
mirror = int(random.uniform(0, 2)) mirror = int(random.uniform(0, 2))
if mirror == 1: if mirror == 1:
img = img[:, ::-1, :] img = img[:, ::-1, :]
for i in xrange(len(sampled_labels)): for i in six.moves.xrange(len(sampled_labels)):
tmp = sampled_labels[i][1] tmp = sampled_labels[i][1]
sampled_labels[i][1] = 1 - sampled_labels[i][3] sampled_labels[i][1] = 1 - sampled_labels[i][3]
sampled_labels[i][3] = 1 - tmp sampled_labels[i][3] = 1 - tmp
......
...@@ -88,16 +88,16 @@ def train(args, ...@@ -88,16 +88,16 @@ def train(args,
if 'coco' in data_args.dataset: if 'coco' in data_args.dataset:
# learning rate decay in 12, 19 pass, respectively # learning rate decay in 12, 19 pass, respectively
if '2014' in train_file_list: if '2014' in train_file_list:
epocs = 82783 / batch_size epocs = 82783 // batch_size
boundaries = [epocs * 12, epocs * 19] boundaries = [epocs * 12, epocs * 19]
elif '2017' in train_file_list: elif '2017' in train_file_list:
epocs = 118287 / batch_size epocs = 118287 // batch_size
boundaries = [epocs * 12, epocs * 19] boundaries = [epocs * 12, epocs * 19]
values = [ values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25 learning_rate, learning_rate * 0.5, learning_rate * 0.25
] ]
elif 'pascalvoc' in data_args.dataset: elif 'pascalvoc' in data_args.dataset:
epocs = 19200 / batch_size epocs = 19200 // batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100] boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
values = [ values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25, learning_rate, learning_rate * 0.5, learning_rate * 0.25,
...@@ -137,7 +137,7 @@ def train(args, ...@@ -137,7 +137,7 @@ def train(args,
model_path = os.path.join(model_save_dir, postfix) model_path = os.path.join(model_save_dir, postfix)
if os.path.isdir(model_path): if os.path.isdir(model_path):
shutil.rmtree(model_path) shutil.rmtree(model_path)
print 'save models to %s' % (model_path) print('save models to %s' % (model_path))
fluid.io.save_persistables(exe, model_path) fluid.io.save_persistables(exe, model_path)
best_map = 0. best_map = 0.
...@@ -193,15 +193,15 @@ def train(args, ...@@ -193,15 +193,15 @@ def train(args,
total_time += end_time - start_time total_time += end_time - start_time
train_avg_loss = np.mean(every_pass_loss) train_avg_loss = np.mean(every_pass_loss)
if devices_num == 1: if devices_num == 1:
print ("kpis train_cost %s" % train_avg_loss) print("kpis train_cost %s" % train_avg_loss)
print ("kpis test_acc %s" % mean_map) print("kpis test_acc %s" % mean_map)
print ("kpis train_speed %s" % (total_time / epoch_idx)) print("kpis train_speed %s" % (total_time / epoch_idx))
else: else:
print ("kpis train_cost_card%s %s" % print("kpis train_cost_card%s %s" %
(devices_num, train_avg_loss)) (devices_num, train_avg_loss))
print ("kpis test_acc_card%s %s" % print("kpis test_acc_card%s %s" %
(devices_num, mean_map)) (devices_num, mean_map))
print ("kpis train_speed_card%s %f" % print("kpis train_speed_card%s %f" %
(devices_num, total_time / epoch_idx)) (devices_num, total_time / epoch_idx))
......
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import distutils.util import distutils.util
import numpy as np import numpy as np
import six
from paddle.fluid import core from paddle.fluid import core
...@@ -37,7 +39,7 @@ def print_arguments(args): ...@@ -37,7 +39,7 @@ def print_arguments(args):
:type args: argparse.Namespace :type args: argparse.Namespace
""" """
print("----------- Configuration Arguments -----------") print("----------- Configuration Arguments -----------")
for arg, value in sorted(vars(args).iteritems()): for arg, value in sorted(six.iteritems(vars(args))):
print("%s: %s" % (arg, value)) print("%s: %s" % (arg, value))
print("------------------------------------------------") print("------------------------------------------------")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册