From c3bae66b87863e9cdd24436e8aeb49772a8ce22b Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Thu, 18 Jun 2020 11:12:19 +0800 Subject: [PATCH] update some --- dygraph/datasets/__init__.py | 1 + dygraph/train.py | 4 ++-- dygraph/utils/__init__.py | 1 - dygraph/utils/utils.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py index ef47f906..072a82f7 100644 --- a/dygraph/datasets/__init__.py +++ b/dygraph/datasets/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .optic_disc_seg import OpticDiscSeg +from .cityscapes import Cityscapes diff --git a/dygraph/train.py b/dygraph/train.py index dc22aca5..88b1ccb6 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -78,7 +78,7 @@ def parse_args(): parser.add_argument( '--pretrained_model', dest='pretrained_model', - help='The path of pretrianed weight', + help='The path of pretrained weight', type=str, default=None) parser.add_argument( @@ -161,7 +161,7 @@ def train(model, optimizer.minimize(loss) model.clear_gradients() logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( - epoch + 1, num_epochs, step + 1, num_steps_each_epoch, + epoch + 1, num_epochs, step + 1, len(batch_sampler), loss.numpy())) if ((epoch + 1) % save_interval_epochs == 0 diff --git a/dygraph/utils/__init__.py b/dygraph/utils/__init__.py index aadc4342..7579cf7f 100644 --- a/dygraph/utils/__init__.py +++ b/dygraph/utils/__init__.py @@ -16,4 +16,3 @@ from . import logging from . import download from .metrics import ConfusionMatrix from .utils import * -from .distributed import DistributedBatchSampler diff --git a/dygraph/utils/utils.py b/dygraph/utils/utils.py index a16f8227..7a450b35 100644 --- a/dygraph/utils/utils.py +++ b/dygraph/utils/utils.py @@ -48,8 +48,8 @@ def get_environ_info(): def load_pretrained_model(model, pretrained_model): - logging.info('Load pretrained model!') if pretrained_model is not None: + logging.info('Load pretrained model!') if os.path.exists(pretrained_model): ckpt_path = os.path.join(pretrained_model, 'model') para_state_dict, _ = fluid.load_dygraph(ckpt_path) -- GitLab