提交 9397dbe7 编写于 作者: F FDInSky 提交者: qingqing01

Fix bug in deeplabv3plus when training on Windows system. (#2867)

上级 e0f58e80
...@@ -32,7 +32,7 @@ depthwise_regularizer = fluid.regularizer.L2DecayRegularizer( ...@@ -32,7 +32,7 @@ depthwise_regularizer = fluid.regularizer.L2DecayRegularizer(
def scope(name): def scope(name):
global name_scope global name_scope
bk = name_scope bk = name_scope
name_scope = name_scope + name + '/' name_scope = name_scope + name + os.sep
yield yield
name_scope = bk name_scope = bk
...@@ -62,11 +62,11 @@ def conv(*args, **kargs): ...@@ -62,11 +62,11 @@ def conv(*args, **kargs):
init_std = 0.09 init_std = 0.09
elif "logit" in name_scope: elif "logit" in name_scope:
init_std = 0.01 init_std = 0.01
elif name_scope.endswith('depthwise/'): elif name_scope.endswith('depthwise' + os.sep):
init_std = 0.33 init_std = 0.33
else: else:
init_std = 0.06 init_std = 0.06
if name_scope.endswith('depthwise/'): if name_scope.endswith('depthwise' + os.sep):
regularizer = depthwise_regularizer regularizer = depthwise_regularizer
else: else:
regularizer = None regularizer = None
......
...@@ -45,25 +45,22 @@ def slice_with_pad(a, s, value=0): ...@@ -45,25 +45,22 @@ def slice_with_pad(a, s, value=0):
class CityscapeDataset: class CityscapeDataset:
def __init__(self, dataset_dir, subset='train', config=default_config): def __init__(self, dataset_dir, subset='train', config=default_config):
label_dirname = os.path.join(dataset_dir, 'gtFine/' + subset) with open(os.path.join(dataset_dir, subset + '.list'), 'r') as fr:
if six.PY2: file_list = fr.readlines()
import commands all_images = []
label_files = commands.getoutput( all_labels = []
"find %s -type f | grep labelTrainIds | sort" % for i in range(len(file_list)):
label_dirname).splitlines() img_gt = file_list[i].strip().split(' ')
else: all_images.append(os.path.join(dataset_dir, img_gt[0]))
import subprocess all_labels.append(os.path.join(dataset_dir, img_gt[1]))
label_files = subprocess.getstatusoutput(
"find %s -type f | grep labelTrainIds | sort" % self.label_files = all_labels
label_dirname)[-1].splitlines() self.img_files = all_images
self.label_files = label_files
self.label_dirname = label_dirname
self.index = 0 self.index = 0
self.subset = subset self.subset = subset
self.dataset_dir = dataset_dir self.dataset_dir = dataset_dir
self.config = config self.config = config
self.reset() self.reset()
print("total number", len(label_files))
def reset(self, shuffle=False): def reset(self, shuffle=False):
self.index = 0 self.index = 0
...@@ -79,10 +76,7 @@ class CityscapeDataset: ...@@ -79,10 +76,7 @@ class CityscapeDataset:
shape = self.config["crop_size"] shape = self.config["crop_size"]
while True: while True:
ln = self.label_files[self.index] ln = self.label_files[self.index]
img_name = os.path.join( img_name = self.img_files[self.index]
self.dataset_dir,
'leftImg8bit/' + self.subset + ln[len(self.label_dirname):])
img_name = img_name.replace('gtFine_labelTrainIds', 'leftImg8bit')
label = cv2.imread(ln) label = cv2.imread(ln)
img = cv2.imread(img_name) img = cv2.imread(img_name)
if img is None: if img is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册