提交 578f83f0 编写于 作者: C chenguowei01

add pretrained model loaded

上级 b5fb876f
......@@ -20,11 +20,12 @@ from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from datasets.dataset import Dataset
from datasets import Dataset
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from val import evaluate
......@@ -133,6 +134,8 @@ def train(model,
os.remove(save_dir)
os.makedirs(save_dir)
load_pretrained_model(model, pretrained_model)
data_generator = train_dataset.generator(
batch_size=batch_size, drop_last=True)
num_steps_each_epoch = train_dataset.num_samples // args.batch_size
......
......@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import os
import os.path as osp
import numpy as np
import six
import yaml
import math
import cv2
import paddle.fluid as fluid
from . import logging
......@@ -58,6 +56,37 @@ def get_environ_info():
return info
def load_pretrained_model(model, pretrained_model):
logging.info('Load pretrained model!')
if pretrained_model is not None:
if osp.exists(pretrained_model):
ckpt_path = osp.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
model_state_dict = model.state_dict()
keys = model_state_dict.keys()
num_params_loaded = 0
for k in keys:
if k not in para_state_dict:
logging.warning("{} is not in pretrained model".format(k))
elif list(para_state_dict[k].shape) != list(
model_state_dict[k].shape):
logging.warning(
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
.format(k, para_state_dict[k].shape,
model_state_dict[k].shape))
else:
model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1
model.set_dict(model_state_dict)
logging.info("There are {}/{} varaibles are loaded.".format(
num_params_loaded, len(model_state_dict)))
else:
raise ValueError(
'The pretrained model directory is not Found: {}'.formnat(
pretrained_model))
def visualize(image, result, save_dir=None, weight=0.6):
"""
Convert segment result to color image, and save added image.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册