提交 9397dbe7 编写于 作者: F FDInSky 提交者: qingqing01

Fix bug in deeplabv3plus when training on Windows system. (#2867)

上级 e0f58e80
......@@ -32,7 +32,7 @@ depthwise_regularizer = fluid.regularizer.L2DecayRegularizer(
def scope(name):
global name_scope
bk = name_scope
name_scope = name_scope + name + '/'
name_scope = name_scope + name + os.sep
yield
name_scope = bk
......@@ -62,11 +62,11 @@ def conv(*args, **kargs):
init_std = 0.09
elif "logit" in name_scope:
init_std = 0.01
elif name_scope.endswith('depthwise/'):
elif name_scope.endswith('depthwise' + os.sep):
init_std = 0.33
else:
init_std = 0.06
if name_scope.endswith('depthwise/'):
if name_scope.endswith('depthwise' + os.sep):
regularizer = depthwise_regularizer
else:
regularizer = None
......
......@@ -45,25 +45,22 @@ def slice_with_pad(a, s, value=0):
class CityscapeDataset:
def __init__(self, dataset_dir, subset='train', config=default_config):
label_dirname = os.path.join(dataset_dir, 'gtFine/' + subset)
if six.PY2:
import commands
label_files = commands.getoutput(
"find %s -type f | grep labelTrainIds | sort" %
label_dirname).splitlines()
else:
import subprocess
label_files = subprocess.getstatusoutput(
"find %s -type f | grep labelTrainIds | sort" %
label_dirname)[-1].splitlines()
self.label_files = label_files
self.label_dirname = label_dirname
with open(os.path.join(dataset_dir, subset + '.list'), 'r') as fr:
file_list = fr.readlines()
all_images = []
all_labels = []
for i in range(len(file_list)):
img_gt = file_list[i].strip().split(' ')
all_images.append(os.path.join(dataset_dir, img_gt[0]))
all_labels.append(os.path.join(dataset_dir, img_gt[1]))
self.label_files = all_labels
self.img_files = all_images
self.index = 0
self.subset = subset
self.dataset_dir = dataset_dir
self.config = config
self.reset()
print("total number", len(label_files))
def reset(self, shuffle=False):
self.index = 0
......@@ -79,10 +76,7 @@ class CityscapeDataset:
shape = self.config["crop_size"]
while True:
ln = self.label_files[self.index]
img_name = os.path.join(
self.dataset_dir,
'leftImg8bit/' + self.subset + ln[len(self.label_dirname):])
img_name = img_name.replace('gtFine_labelTrainIds', 'leftImg8bit')
img_name = self.img_files[self.index]
label = cv2.imread(ln)
img = cv2.imread(img_name)
if img is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册