diff --git a/dygraph/datasets/__init__.py b/dygraph/datasets/__init__.py index ef47f906eeab5e5f8fc53a3465dd69b8028918e3..072a82f7409a9369d2c3b1bdba603527eac0bb7f 100644 --- a/dygraph/datasets/__init__.py +++ b/dygraph/datasets/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .optic_disc_seg import OpticDiscSeg +from .cityscapes import Cityscapes diff --git a/dygraph/train.py b/dygraph/train.py index dc22aca5b99aeab4f39f2b18ed0d9f5fd19588de..88b1ccb64bcb7f9862a4f81a17ee1cd392db36ae 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -78,7 +78,7 @@ def parse_args(): parser.add_argument( '--pretrained_model', dest='pretrained_model', - help='The path of pretrianed weight', + help='The path of pretrained weight', type=str, default=None) parser.add_argument( @@ -161,7 +161,7 @@ def train(model, optimizer.minimize(loss) model.clear_gradients() logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format( - epoch + 1, num_epochs, step + 1, num_steps_each_epoch, + epoch + 1, num_epochs, step + 1, len(batch_sampler), loss.numpy())) if ((epoch + 1) % save_interval_epochs == 0 diff --git a/dygraph/utils/__init__.py b/dygraph/utils/__init__.py index aadc434287a17aea11d61137b63f26a5cdbc0443..7579cf7f0ed9f051b154d7bc2f99fc25ac246d4a 100644 --- a/dygraph/utils/__init__.py +++ b/dygraph/utils/__init__.py @@ -16,4 +16,3 @@ from . import logging from . import download from .metrics import ConfusionMatrix from .utils import * -from .distributed import DistributedBatchSampler diff --git a/dygraph/utils/utils.py b/dygraph/utils/utils.py index a16f82276f3b9274a4e8854fef1a2597ad98f3bd..7a450b352e0dcf98c1eeaa093878c9b3ba649dfd 100644 --- a/dygraph/utils/utils.py +++ b/dygraph/utils/utils.py @@ -48,8 +48,8 @@ def get_environ_info(): def load_pretrained_model(model, pretrained_model): - logging.info('Load pretrained model!') if pretrained_model is not None: + logging.info('Load pretrained model!') if os.path.exists(pretrained_model): ckpt_path = os.path.join(pretrained_model, 'model') para_state_dict, _ = fluid.load_dygraph(ckpt_path)