提交 c3bae66b 编写于 作者: C chenguowei01

update some

上级 6b4a7f02
......@@ -13,3 +13,4 @@
# limitations under the License.
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
......@@ -78,7 +78,7 @@ def parse_args():
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrianed weight',
help='The path of pretrained weight',
type=str,
default=None)
parser.add_argument(
......@@ -161,7 +161,7 @@ def train(model,
optimizer.minimize(loss)
model.clear_gradients()
logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
epoch + 1, num_epochs, step + 1, num_steps_each_epoch,
epoch + 1, num_epochs, step + 1, len(batch_sampler),
loss.numpy()))
if ((epoch + 1) % save_interval_epochs == 0
......
......@@ -16,4 +16,3 @@ from . import logging
from . import download
from .metrics import ConfusionMatrix
from .utils import *
from .distributed import DistributedBatchSampler
......@@ -48,8 +48,8 @@ def get_environ_info():
def load_pretrained_model(model, pretrained_model):
logging.info('Load pretrained model!')
if pretrained_model is not None:
logging.info('Load pretrained model!')
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册