提交 c3bae66b 编写于 作者: C chenguowei01

update some

上级 6b4a7f02
...@@ -13,3 +13,4 @@ ...@@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from .optic_disc_seg import OpticDiscSeg from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
...@@ -78,7 +78,7 @@ def parse_args(): ...@@ -78,7 +78,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--pretrained_model', '--pretrained_model',
dest='pretrained_model', dest='pretrained_model',
help='The path of pretrianed weight', help='The path of pretrained weight',
type=str, type=str,
default=None) default=None)
parser.add_argument( parser.add_argument(
...@@ -161,7 +161,7 @@ def train(model, ...@@ -161,7 +161,7 @@ def train(model,
optimizer.minimize(loss) optimizer.minimize(loss)
model.clear_gradients() model.clear_gradients()
logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
epoch + 1, num_epochs, step + 1, num_steps_each_epoch, epoch + 1, num_epochs, step + 1, len(batch_sampler),
loss.numpy())) loss.numpy()))
if ((epoch + 1) % save_interval_epochs == 0 if ((epoch + 1) % save_interval_epochs == 0
......
...@@ -16,4 +16,3 @@ from . import logging ...@@ -16,4 +16,3 @@ from . import logging
from . import download from . import download
from .metrics import ConfusionMatrix from .metrics import ConfusionMatrix
from .utils import * from .utils import *
from .distributed import DistributedBatchSampler
...@@ -48,8 +48,8 @@ def get_environ_info(): ...@@ -48,8 +48,8 @@ def get_environ_info():
def load_pretrained_model(model, pretrained_model): def load_pretrained_model(model, pretrained_model):
logging.info('Load pretrained model!')
if pretrained_model is not None: if pretrained_model is not None:
logging.info('Load pretrained model!')
if os.path.exists(pretrained_model): if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model') ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path) para_state_dict, _ = fluid.load_dygraph(ckpt_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册