提交 578f83f0 编写于 作者: C chenguowei01

add pretrained model loaded

上级 b5fb876f
...@@ -20,11 +20,12 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -20,11 +20,12 @@ from paddle.fluid.dygraph.base import to_variable
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from datasets.dataset import Dataset from datasets import Dataset
import transforms as T import transforms as T
import models import models
import utils.logging as logging import utils.logging as logging
from utils import get_environ_info from utils import get_environ_info
from utils import load_pretrained_model
from val import evaluate from val import evaluate
...@@ -133,6 +134,8 @@ def train(model, ...@@ -133,6 +134,8 @@ def train(model,
os.remove(save_dir) os.remove(save_dir)
os.makedirs(save_dir) os.makedirs(save_dir)
load_pretrained_model(model, pretrained_model)
data_generator = train_dataset.generator( data_generator = train_dataset.generator(
batch_size=batch_size, drop_last=True) batch_size=batch_size, drop_last=True)
num_steps_each_epoch = train_dataset.num_samples // args.batch_size num_steps_each_epoch = train_dataset.num_samples // args.batch_size
......
...@@ -12,15 +12,13 @@ ...@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys
import time
import os import os
import os.path as osp import os.path as osp
import numpy as np import numpy as np
import six
import yaml
import math import math
import cv2 import cv2
import paddle.fluid as fluid
from . import logging from . import logging
...@@ -58,6 +56,37 @@ def get_environ_info(): ...@@ -58,6 +56,37 @@ def get_environ_info():
return info return info
def load_pretrained_model(model, pretrained_model):
logging.info('Load pretrained model!')
if pretrained_model is not None:
if osp.exists(pretrained_model):
ckpt_path = osp.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
model_state_dict = model.state_dict()
keys = model_state_dict.keys()
num_params_loaded = 0
for k in keys:
if k not in para_state_dict:
logging.warning("{} is not in pretrained model".format(k))
elif list(para_state_dict[k].shape) != list(
model_state_dict[k].shape):
logging.warning(
"[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
.format(k, para_state_dict[k].shape,
model_state_dict[k].shape))
else:
model_state_dict[k] = para_state_dict[k]
num_params_loaded += 1
model.set_dict(model_state_dict)
logging.info("There are {}/{} varaibles are loaded.".format(
num_params_loaded, len(model_state_dict)))
else:
raise ValueError(
'The pretrained model directory is not Found: {}'.formnat(
pretrained_model))
def visualize(image, result, save_dir=None, weight=0.6): def visualize(image, result, save_dir=None, weight=0.6):
""" """
Convert segment result to color image, and save added image. Convert segment result to color image, and save added image.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册