未验证 提交 e99b0ac6 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #367 from wuyefeilin/dygraph

update train.py and get_environ_info.py
...@@ -79,7 +79,7 @@ def train(model, ...@@ -79,7 +79,7 @@ def train(model,
train_batch_cost = 0.0 train_batch_cost = 0.0
timer.start() timer.start()
iter = 0 iter = start_iter
while iter < iters: while iter < iters:
for data in loader: for data in loader:
iter += 1 iter += 1
......
...@@ -64,7 +64,6 @@ class UNet(fluid.dygraph.Layer): ...@@ -64,7 +64,6 @@ class UNet(fluid.dygraph.Layer):
""" """
if pretrained_model is not None: if pretrained_model is not None:
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
utils.load_pretrained_model(self, pretrained_model) utils.load_pretrained_model(self, pretrained_model)
else: else:
raise Exception('Pretrained model is not found: {}'.format( raise Exception('Pretrained model is not found: {}'.format(
......
...@@ -83,11 +83,16 @@ def get_environ_info(): ...@@ -83,11 +83,16 @@ def get_environ_info():
env_info = {} env_info = {}
env_info['System Platform'] = sys.platform env_info['System Platform'] = sys.platform
if env_info['System Platform'] == 'linux': if env_info['System Platform'] == 'linux':
lsb_v = subprocess.check_output(['lsb_release', '-v']).decode().strip() try:
lsb_v = lsb_v.replace('\t', ' ') lsb_v = subprocess.check_output(['lsb_release',
lsb_d = subprocess.check_output(['lsb_release', '-d']).decode().strip() '-v']).decode().strip()
lsb_d = lsb_d.replace('\t', ' ') lsb_v = lsb_v.replace('\t', ' ')
env_info['LSB'] = [lsb_v, lsb_d] lsb_d = subprocess.check_output(['lsb_release',
'-d']).decode().strip()
lsb_d = lsb_d.replace('\t', ' ')
env_info['LSB'] = [lsb_v, lsb_d]
except:
pass
env_info['Python'] = sys.version.replace('\n', '') env_info['Python'] = sys.version.replace('\n', '')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册