Unverified    Commit 4d1187d5    authored by: huangjun12    committed by: GitHub

update tsn Reader using dataloader and pipline (#4856)

Parent: bde994e1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
from PIL import Image
class Scale(object):
"""
Scale images.
Args:
short_size(float | int): Short size of an image will be scaled to the short_size.
"""
def __init__(self, short_size):
self.short_size = short_size
def __call__(self, imgs):
"""
Performs resize operations.
Args:
imgs: List where each item is a PIL.Image.
For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
return:
resized_imgs: List where each item is a PIL.Image after scaling.
"""
resized_imgs = []
for i in range(len(imgs)):
img = imgs[i]
w, h = img.size
if (w <= h and w == self.short_size) or (h <= w and
h == self.short_size):
resized_imgs.append(img)
continue
            if w < h:
                ow = self.short_size
                # note: the long side is fixed to a 4:3 multiple of short_size
                # rather than preserving the original aspect ratio
                oh = int(self.short_size * 4.0 / 3.0)
                resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
            else:
                oh = self.short_size
                ow = int(self.short_size * 4.0 / 3.0)
                resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
return resized_imgs
class RandomCrop(object):
"""
Random crop images.
Args:
target_size(int): Random crop a square with the target_size from an image.
"""
def __init__(self, target_size):
self.target_size = target_size
def __call__(self, imgs):
"""
Performs random crop operations.
Args:
imgs: List where each item is a PIL.Image.
For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
return:
crop_imgs: List where each item is a PIL.Image after random crop.
"""
w, h = imgs[0].size
th, tw = self.target_size, self.target_size
        assert (w >= self.target_size) and (h >= self.target_size), \
            "image width({}) and height({}) should be larger than crop size({})".format(
                w, h, self.target_size)
crop_images = []
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
for img in imgs:
if w == tw and h == th:
crop_images.append(img)
else:
crop_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return crop_images
class CenterCrop(object):
"""
Center crop images.
Args:
target_size(int): Center crop a square with the target_size from an image.
"""
def __init__(self, target_size):
self.target_size = target_size
def __call__(self, imgs):
"""
Performs Center crop operations.
Args:
imgs: List where each item is a PIL.Image.
For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
return:
ccrop_imgs: List where each item is a PIL.Image after Center crop.
"""
ccrop_imgs = []
for img in imgs:
w, h = img.size
th, tw = self.target_size, self.target_size
            assert (w >= self.target_size) and (h >= self.target_size), \
                "image width({}) and height({}) should be larger than crop size({})".format(
                    w, h, self.target_size)
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
ccrop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
return ccrop_imgs
class RandomFlip(object):
"""
Random Flip images.
Args:
p(float): Random flip images with the probability p.
"""
def __init__(self, p=0.5):
self.p = p
def __call__(self, imgs):
"""
Performs random flip operations.
Args:
imgs: List where each item is a PIL.Image.
For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
return:
flip_imgs: List where each item is a PIL.Image after random flip.
"""
v = random.random()
if v < self.p:
flip_imgs = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
return flip_imgs
else:
return imgs
class Image2Array(object):
"""
transfer PIL.Image to Numpy array and transpose dimensions from 'dhwc' to 'dchw'.
"""
def __init__(self):
self.format = "dhwc"
def __call__(self, imgs):
"""
Performs Image to NumpyArray operations.
Args:
imgs: List where each item is a PIL.Image.
For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
return:
np_imgs: Numpy array.
"""
np_imgs = np.array(
[np.array(img).astype('float32') for img in imgs]) #dhwc
np_imgs = np_imgs.transpose(0, 3, 1, 2) #dchw
return np_imgs
class Normalization(object):
"""
Normalization.
Args:
        mean(list[float]): mean values of different channels.
        std(list[float]): std values of different channels.
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, imgs):
"""
Performs normalization operations.
Args:
imgs: Numpy array.
return:
np_imgs: Numpy array after normalization.
"""
        norm_imgs = imgs / 255  # scale pixel values to [0, 1]
        norm_imgs -= self.mean
        norm_imgs /= self.std
return norm_imgs
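For reference, a minimal sketch of how these ops chain together at eval time; the synthetic input frames and the ImageNet-style mean/std values below are placeholders, not part of this commit:

import numpy as np
from PIL import Image

# hypothetical input: a short list of decoded frames as PIL.Images
frames = [Image.new('RGB', (320, 240)) for _ in range(3)]

mean = np.array([0.485, 0.456, 0.406]).reshape([3, 1, 1]).astype(np.float32)
std = np.array([0.229, 0.224, 0.225]).reshape([3, 1, 1]).astype(np.float32)

ops = [Scale(256), CenterCrop(224), Image2Array(), Normalization(mean, std)]
data = frames
for op in ops:
    data = op(data)
print(data.shape)  # (3, 3, 224, 224): frames x channels x height x width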
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import logging
from paddle.io import Dataset
from augmentations import *
from loader import *
logger = logging.getLogger(__name__)
class TSN_UCF101_Dataset(Dataset):
def __init__(self, cfg, mode):
self.mode = mode
self.format = cfg.MODEL.format #'videos' or 'frames'
self.seg_num = cfg.MODEL.seg_num
self.seglen = cfg.MODEL.seglen
self.short_size = cfg.TRAIN.short_size
self.target_size = cfg.TRAIN.target_size
self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
[3, 1, 1]).astype(np.float32)
self.img_std = np.array(cfg.MODEL.image_std).reshape(
[3, 1, 1]).astype(np.float32)
self.filelist = cfg[mode.upper()]['filelist']
self._construct_loader()
def _construct_loader(self):
"""
Construct the video loader.
"""
self._num_retries = 5
self._path_to_videos = []
self._labels = []
self._num_frames = []
with open(self.filelist, "r") as f:
for clip_idx, path_label in enumerate(f.read().splitlines()):
if self.format == "videos":
path, label = path_label.split()
self._path_to_videos.append(path + '.avi')
self._num_frames.append(0) # unused
self._labels.append(int(label))
elif self.format == "frames":
path, num_frames, label = path_label.split()
self._path_to_videos.append(path)
self._num_frames.append(int(num_frames))
self._labels.append(int(label))
def __len__(self):
return len(self._path_to_videos)
def __getitem__(self, idx):
for ir in range(self._num_retries):
path = self._path_to_videos[idx]
num_frames = self._num_frames[idx]
try:
frames = self.pipline(
path,
num_frames,
format=self.format,
seg_num=self.seg_num,
seglen=self.seglen,
short_size=self.short_size,
target_size=self.target_size,
img_mean=self.img_mean,
img_std=self.img_std,
mode=self.mode)
            except Exception:
                if ir < self._num_retries - 1:
                    logger.error(
                        'Error when loading {}, tried {} times, will try again'.
                        format(path, ir + 1))
                    idx = random.randint(0, len(self._path_to_videos) - 1)
                    continue
                else:
                    logger.error(
                        'Error when loading {}, tried {} times, will not try again'.
                        format(path, ir + 1))
                    return None, None
label = self._labels[idx]
return frames, np.array([label]) #, np.array([idx])
def pipline(self, filepath, num_frames, format, seg_num, seglen, short_size,
target_size, img_mean, img_std, mode):
#Loader
if format == 'videos':
Loader_ops = [
VideoDecoder(filepath), VideoSampler(seg_num, seglen, mode)
]
elif format == 'frames':
Loader_ops = [
FrameLoader(filepath, num_frames, seg_num, seglen, mode)
]
#Augmentation
if mode == 'train':
Aug_ops = [
Scale(short_size), RandomCrop(target_size), RandomFlip(),
Image2Array(), Normalization(img_mean, img_std)
]
else:
Aug_ops = [
Scale(short_size), CenterCrop(target_size), Image2Array(),
Normalization(img_mean, img_std)
]
        ops = Loader_ops + Aug_ops
        # run the loader first, then pass its output through each augmentation
        data = ops[0]()
        for op in ops[1:]:
            data = op(data)
        return data
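A minimal sketch of driving this dataset directly; the AttrDict-style config below is a stand-in for the repo's real config utilities, and all field values are illustrative:

# hypothetical config object supporting both attribute access (cfg.MODEL.seg_num)
# and key access (cfg['TRAIN']), as the class requires
class AttrDict(dict):
    __getattr__ = dict.__getitem__

cfg = AttrDict(
    MODEL=AttrDict(format='frames', seg_num=3, seglen=1,
                   image_mean=[0.485, 0.456, 0.406],
                   image_std=[0.229, 0.224, 0.225]),
    TRAIN=AttrDict(short_size=256, target_size=224,
                   filelist='ucf101_train_split_1_rawframes.txt'))

# each line of a 'frames' filelist is expected to be: <path> <num_frames> <label>
dataset = TSN_UCF101_Dataset(cfg, 'train')
frames, label = dataset[0]  # frames: (seg_num * seglen, 3, 224, 224) float32 array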
@@ -103,7 +103,7 @@ def parse_args():
         default='rawframes',
         choices=['rawframes', 'videos'])
     parser.add_argument('--out_list_path', type=str, default='./')
-    parser.add_argument('--shuffle', action='store_true', default=True)
+    parser.add_argument('--shuffle', action='store_true', default=False)
     args = parser.parse_args()
     return args
@@ -146,11 +146,12 @@ def main():
         lists = build_split_list(split_tp[i], frame_info, shuffle=args.shuffle)
         filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
+        PATH = os.path.abspath(args.frame_path)
         with open(os.path.join(out_path, filename), 'w') as f:
-            f.writelines(lists[0])
+            f.writelines([os.path.join(PATH, item) for item in lists[0]])
         filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
         with open(os.path.join(out_path, filename), 'w') as f:
-            f.writelines(lists[1])
+            f.writelines([os.path.join(PATH, item) for item in lists[1]])
 if __name__ == "__main__":
...
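With this change the generated split files carry absolute frame paths. A line of ucf101_train_split_1_rawframes.txt then looks roughly like the following (path, frame count, and label are illustrative), matching the <path> <num_frames> <label> format the frames loader expects:

/abs/path/to/rawframes/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01 121 0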
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import random
from PIL import Image
class VideoDecoder(object):
"""
Decode mp4 file to frames.
Args:
filepath: the file path of mp4 file
"""
def __init__(self, filepath):
self.filepath = filepath
def __call__(self):
"""
Perform mp4 decode operations.
return:
List where each item is a numpy array after decoder.
"""
        cap = cv2.VideoCapture(self.filepath)
        videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        sampledFrames = []
        for i in range(videolen):
            ret, frame = cap.read()
            # maybe first frame is empty
            if not ret:
                continue
            img = frame[:, :, ::-1]  # convert OpenCV's BGR channel order to RGB
            sampledFrames.append(img)
        cap.release()
        return sampledFrames
class VideoSampler(object):
"""
Sample frames.
Args:
num_seg(int): number of segments.
seg_len(int): number of sampled frames in each segment.
mode(str): 'train', 'test' or 'infer'
"""
def __init__(self, num_seg, seg_len, mode):
self.num_seg = num_seg
self.seg_len = seg_len
self.mode = mode
def __call__(self, frames):
"""
Args:
frames: List where each item is a numpy array decoding from video.
return:
List where each item is a PIL.Image after sampling.
"""
average_dur = int(len(frames) / self.num_seg)
imgs = []
        for i in range(self.num_seg):
            idx = 0
            if self.mode == 'train':
                # training: pick a random start offset within each segment
                if average_dur >= self.seg_len:
                    idx = random.randint(0, average_dur - self.seg_len)
                    idx += i * average_dur
                elif average_dur >= 1:
                    idx += i * average_dur
                else:
                    idx = i
            else:
                # validation/test: use the center offset of each segment
                if average_dur >= self.seg_len:
                    idx = (average_dur - 1) // 2
                    idx += i * average_dur
                elif average_dur >= 1:
                    idx += i * average_dur
                else:
                    idx = i
for jj in range(idx, idx + self.seg_len):
imgbuf = frames[int(jj % len(frames))]
img = Image.fromarray(imgbuf, mode='RGB')
imgs.append(img)
return imgs
class FrameLoader(object):
"""
Load frames.
Args:
filepath(str): the file path of frames file.
num_frames(int): number of frames in a video file.
num_seg(int): number of segments.
seg_len(int): number of sampled frames in each segment.
mode(str): 'train', 'test' or 'infer'.
"""
def __init__(self, filepath, num_frames, num_seg, seg_len, mode):
self.filepath = filepath
self.num_frames = num_frames
self.num_seg = num_seg
self.seg_len = seg_len
self.mode = mode
def __call__(self):
"""
return:
imgs: List where each item is a PIL.Image.
"""
average_dur = int(self.num_frames / self.num_seg)
imgs = []
        # same TSN segment sampling scheme as VideoSampler above
        for i in range(self.num_seg):
idx = 0
if self.mode == 'train':
if average_dur >= self.seg_len:
idx = random.randint(0, average_dur - self.seg_len)
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
else:
if average_dur >= self.seg_len:
idx = (average_dur - 1) // 2
idx += i * average_dur
elif average_dur >= 1:
idx += i * average_dur
else:
idx = i
for jj in range(idx, idx + self.seg_len):
img = Image.open(
os.path.join(self.filepath, 'img_{:05d}.jpg'.format(
jj + 1))).convert('RGB')
imgs.append(img)
return imgs
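Both samplers share the TSN scheme: split the video into num_seg equal segments and take seg_len consecutive frames from each, randomly offset in training and center-aligned otherwise. A standalone sketch of the eval-mode index arithmetic, mirroring the branches above:

def eval_indices(num_frames, num_seg, seg_len):
    # center-offset TSN sampling, as in the eval branches of
    # VideoSampler and FrameLoader
    average_dur = num_frames // num_seg
    indices = []
    for i in range(num_seg):
        if average_dur >= seg_len:
            idx = (average_dur - 1) // 2 + i * average_dur
        elif average_dur >= 1:
            idx = i * average_dur
        else:
            idx = i
        indices.extend(range(idx, idx + seg_len))
    return indices

print(eval_indices(120, 3, 1))  # [19, 59, 99]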
@@ -13,8 +13,6 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
@@ -24,19 +22,19 @@ TRAIN:
     l2_weight_decay: 1e-4
     momentum: 0.9
     total_videos: 9738
+    num_workers: 4
+    use_shuffle: True

 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4

 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
@@ -13,8 +13,6 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
@@ -24,19 +22,19 @@ TRAIN:
     l2_weight_decay: 1e-4
     momentum: 0.9
     total_videos: 9738
+    num_workers: 4
+    use_shuffle: True

 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4

 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
@@ -13,8 +13,6 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
@@ -24,19 +22,19 @@ TRAIN:
     l2_weight_decay: 1e-4
     momentum: 0.9
     total_videos: 9738
+    num_workers: 4
+    use_shuffle: True

 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4

 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
@@ -13,8 +13,6 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
@@ -24,19 +22,19 @@ TRAIN:
     l2_weight_decay: 1e-4
     momentum: 0.9
     total_videos: 9738
+    num_workers: 4
+    use_shuffle: True

 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4

 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
     batch_size: 32
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
@@ -27,6 +27,9 @@ from paddle.fluid.dygraph.base import to_variable
 from model import TSN_ResNet
 from utils.config_utils import *
 from reader.ucf101_reader import UCF101Reader
+import paddle
+from paddle.io import DataLoader, DistributedBatchSampler
+from compose import TSN_UCF101_Dataset

 logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
@@ -111,19 +114,15 @@ def init_model(model, pre_state_dict):
     return model

-def val(epoch, model, cfg, args):
-    reader = UCF101Reader(name="TSN", mode="valid", cfg=cfg)
-    reader = reader.create_reader()
+def val(epoch, model, val_loader, cfg, args):
     total_loss = 0.0
     total_acc1 = 0.0
     total_acc5 = 0.0
     total_sample = 0

-    for batch_id, data in enumerate(reader()):
-        x_data = np.array([item[0] for item in data])
-        y_data = np.array([item[1] for item in data]).reshape([-1, 1])
-        imgs = to_variable(x_data)
-        labels = to_variable(y_data)
+    for batch_id, data in enumerate(val_loader):
+        imgs = paddle.to_tensor(data[0])
+        labels = paddle.to_tensor(data[1])
         labels.stop_gradient = True

         outputs = model(imgs)
@@ -210,11 +209,30 @@ def train(args):
         gpus = gpus.split(",")
         num_gpus = len(gpus)
         bs_denominator = num_gpus
-    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
-                                        bs_denominator)
+    bs_train_single = int(train_config.TRAIN.batch_size / bs_denominator)
+    bs_val_single = int(valid_config.VALID.batch_size / bs_denominator)

-    train_reader = UCF101Reader(name="TSN", mode="train", cfg=train_config)
-    train_reader = train_reader.create_reader()
+    train_dataset = TSN_UCF101_Dataset(train_config, 'train')
+    val_dataset = TSN_UCF101_Dataset(valid_config, 'valid')
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=bs_train_single,
+        shuffle=train_config.TRAIN.use_shuffle,
+        drop_last=True)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        places=place,
+        num_workers=train_config.TRAIN.num_workers,
+        return_list=True)
+    val_sampler = DistributedBatchSampler(
+        val_dataset, batch_size=bs_val_single)
+    val_loader = DataLoader(
+        val_dataset,
+        batch_sampler=val_sampler,
+        places=place,
+        num_workers=valid_config.VALID.num_workers,
+        return_list=True)

     if use_data_parallel:
         # (data_parallel step4/6)
@@ -234,12 +252,10 @@ def train(args):
         total_acc5 = 0.0
         total_sample = 0
         batch_start = time.time()
-        for batch_id, data in enumerate(train_reader()):
+        for batch_id, data in enumerate(train_loader):
             train_reader_cost = time.time() - batch_start
-            x_data = np.array([item[0] for item in data]).astype("float32")
-            y_data = np.array([item[1] for item in data]).reshape([-1, 1])
-            imgs = to_variable(x_data)
-            labels = to_variable(y_data)
+            imgs = paddle.to_tensor(data[0])
+            labels = paddle.to_tensor(data[1])
             labels.stop_gradient = True

             outputs = video_model(imgs)
@@ -292,13 +308,13 @@ def train(args):
                 model_path = os.path.join(
                     args.checkpoint,
                     "_" + model_path_pre + "_epoch{}".format(epoch))
-                fluid.dygraph.save_dygraph(
-                    video_model.state_dict(), model_path)
+                fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
                 fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)

                 if args.validate:
                     video_model.eval()
-                    val_acc = val(epoch, video_model, valid_config, args)
+                    val_acc = val(epoch, video_model, val_loader, valid_config,
+                                  args)

                     # save the best parameters in trainging stage
                     if epoch == 1:
...
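Taken together, the new input path reduces to roughly the following sketch, assuming single-card dygraph mode; parse_config and merge_configs are assumed to come from the repo's utils.config_utils, and the loop body is elided:

import paddle.fluid as fluid
from paddle.io import DataLoader, DistributedBatchSampler
from compose import TSN_UCF101_Dataset
from utils.config_utils import parse_config, merge_configs  # assumed helpers

place = fluid.CUDAPlace(0)  # assumption: one GPU, no data parallelism
cfg = parse_config('./configs/tsn.yaml')
train_config = merge_configs(cfg, 'train', {})
train_dataset = TSN_UCF101_Dataset(train_config, 'train')
train_sampler = DistributedBatchSampler(
    train_dataset,
    batch_size=train_config.TRAIN.batch_size,
    shuffle=train_config.TRAIN.use_shuffle,
    drop_last=True)
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    places=place,
    num_workers=train_config.TRAIN.num_workers,
    return_list=True)
for batch_id, data in enumerate(train_loader):
    imgs, labels = data  # dense float32 batch and integer labels
    # forward/backward as in train() above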