From 4d1187d57a43412e25b50972aa545dd1b13268c0 Mon Sep 17 00:00:00 2001
From: huangjun12 <2399845970@qq.com>
Date: Wed, 16 Sep 2020 20:36:02 +0800
Subject: [PATCH] update tsn Reader using dataloader and pipeline (#4856)

---
 dygraph/tsn/augmentations.py                     | 209 ++++++++++++++++++
 dygraph/tsn/compose.py                           | 125 +++++++++++
 .../dataset/ucf101/build_ucf101_file_list.py     |   7 +-
 dygraph/tsn/loader.py                            | 149 +++++++++++++
 dygraph/tsn/multi_tsn_frame.yaml                 |  10 +-
 dygraph/tsn/multi_tsn_video.yaml                 |  14 +-
 dygraph/tsn/single_tsn_frame.yaml                |  12 +-
 dygraph/tsn/single_tsn_video.yaml                |  12 +-
 dygraph/tsn/train.py                             |  58 +++--
 9 files changed, 544 insertions(+), 52 deletions(-)
 create mode 100644 dygraph/tsn/augmentations.py
 create mode 100644 dygraph/tsn/compose.py
 create mode 100644 dygraph/tsn/loader.py

diff --git a/dygraph/tsn/augmentations.py b/dygraph/tsn/augmentations.py
new file mode 100644
index 00000000..ac76f0cf
--- /dev/null
+++ b/dygraph/tsn/augmentations.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import numpy as np
+from PIL import Image
+
+
+class Scale(object):
+    """
+    Scale images.
+
+    Args:
+        short_size(float | int): the short side of each image is scaled to short_size.
+    """
+
+    def __init__(self, short_size):
+        self.short_size = short_size
+
+    def __call__(self, imgs):
+        """
+        Performs resize operations.
+        Args:
+            imgs: List where each item is a PIL.Image.
+                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
+        return:
+            resized_imgs: List where each item is a PIL.Image after scaling.
+        """
+        resized_imgs = []
+        for i in range(len(imgs)):
+            img = imgs[i]
+            w, h = img.size
+            if (w <= h and w == self.short_size) or (h <= w and
+                                                     h == self.short_size):
+                resized_imgs.append(img)
+                continue
+
+            # NOTE: the long side is fixed to 4/3 of short_size instead of
+            # preserving the original aspect ratio.
+            if w < h:
+                ow = self.short_size
+                oh = int(self.short_size * 4.0 / 3.0)
+                resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+            else:
+                oh = self.short_size
+                ow = int(self.short_size * 4.0 / 3.0)
+                resized_imgs.append(img.resize((ow, oh), Image.BILINEAR))
+
+        return resized_imgs
+
+
+class RandomCrop(object):
+    """
+    Random crop images.
+
+    Args:
+        target_size(int): randomly crop a square of side target_size from an image.
+    """
+
+    def __init__(self, target_size):
+        self.target_size = target_size
+
+    def __call__(self, imgs):
+        """
+        Performs random crop operations.
+        Args:
+            imgs: List where each item is a PIL.Image.
+                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
+        return:
+            crop_imgs: List where each item is a PIL.Image after random crop.
+        """
+        w, h = imgs[0].size
+        th, tw = self.target_size, self.target_size
+
+        assert (w >= self.target_size) and (h >= self.target_size), \
+            "image width({}) and height({}) should be no smaller than crop size({})".format(
+                w, h, self.target_size)
+
+        crop_imgs = []
+        x1 = random.randint(0, w - tw)
+        y1 = random.randint(0, h - th)
+
+        for img in imgs:
+            if w == tw and h == th:
+                crop_imgs.append(img)
+            else:
+                crop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+        return crop_imgs
+
+
+class CenterCrop(object):
+    """
+    Center crop images.
+
+    Args:
+        target_size(int): center crop a square of side target_size from an image.
+    """
+
+    def __init__(self, target_size):
+        self.target_size = target_size
+
+    def __call__(self, imgs):
+        """
+        Performs center crop operations.
+        Args:
+            imgs: List where each item is a PIL.Image.
+                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
+        return:
+            ccrop_imgs: List where each item is a PIL.Image after center crop.
+        """
+        ccrop_imgs = []
+        for img in imgs:
+            w, h = img.size
+            th, tw = self.target_size, self.target_size
+            assert (w >= self.target_size) and (h >= self.target_size), \
+                "image width({}) and height({}) should be no smaller than crop size({})".format(
+                    w, h, self.target_size)
+            x1 = int(round((w - tw) / 2.))
+            y1 = int(round((h - th) / 2.))
+            ccrop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th)))
+
+        return ccrop_imgs
+
+
+class RandomFlip(object):
+    """
+    Random Flip images.
+
+    Args:
+        p(float): probability of flipping the images horizontally.
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, imgs):
+        """
+        Performs random flip operations.
+        Args:
+            imgs: List where each item is a PIL.Image.
+                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
+        return:
+            flip_imgs: List where each item is a PIL.Image after random flip.
+        """
+        v = random.random()
+        if v < self.p:
+            flip_imgs = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
+            return flip_imgs
+        else:
+            return imgs
+
+
+class Image2Array(object):
+    """
+    Transfer a list of PIL.Images to a numpy array and transpose the
+    dimensions from 'dhwc' to 'dchw'.
+    """
+
+    def __init__(self):
+        self.format = "dhwc"
+
+    def __call__(self, imgs):
+        """
+        Performs Image-to-NumpyArray operations.
+        Args:
+            imgs: List where each item is a PIL.Image.
+                  For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...]
+        return:
+            np_imgs: Numpy array.
+        """
+        np_imgs = np.array(
+            [np.array(img).astype('float32') for img in imgs])  # dhwc
+        np_imgs = np_imgs.transpose(0, 3, 1, 2)  # dchw
+        return np_imgs
+
+
+class Normalization(object):
+    """
+    Normalization.
+    Args:
+        mean(list[float]): mean values of different channels.
+        std(list[float]): std values of different channels.
+    """
+
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, imgs):
+        """
+        Performs normalization operations.
+        Args:
+            imgs: Numpy array.
+        return:
+            norm_imgs: Numpy array after normalization.
+        """
+        norm_imgs = imgs / 255
+        norm_imgs -= self.mean
+        norm_imgs /= self.std
+        return norm_imgs
diff --git a/dygraph/tsn/compose.py b/dygraph/tsn/compose.py
new file mode 100644
index 00000000..81198361
--- /dev/null
+++ b/dygraph/tsn/compose.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import numpy as np
+import logging
+from paddle.io import Dataset
+from augmentations import *
+from loader import *
+
+logger = logging.getLogger(__name__)
+
+
+class TSN_UCF101_Dataset(Dataset):
+    def __init__(self, cfg, mode):
+        self.mode = mode
+        self.format = cfg.MODEL.format  # 'videos' or 'frames'
+        self.seg_num = cfg.MODEL.seg_num
+        self.seglen = cfg.MODEL.seglen
+        self.short_size = cfg.TRAIN.short_size
+        self.target_size = cfg.TRAIN.target_size
+        self.img_mean = np.array(cfg.MODEL.image_mean).reshape(
+            [3, 1, 1]).astype(np.float32)
+        self.img_std = np.array(cfg.MODEL.image_std).reshape(
+            [3, 1, 1]).astype(np.float32)
+
+        self.filelist = cfg[mode.upper()]['filelist']
+
+        self._construct_loader()
+
+    def _construct_loader(self):
+        """
+        Construct the video loader.
+        """
+        self._num_retries = 5
+        self._path_to_videos = []
+        self._labels = []
+        self._num_frames = []
+        with open(self.filelist, "r") as f:
+            for clip_idx, path_label in enumerate(f.read().splitlines()):
+                if self.format == "videos":
+                    path, label = path_label.split()
+                    self._path_to_videos.append(path + '.avi')
+                    self._num_frames.append(0)  # unused
+                    self._labels.append(int(label))
+                elif self.format == "frames":
+                    path, num_frames, label = path_label.split()
+                    self._path_to_videos.append(path)
+                    self._num_frames.append(int(num_frames))
+                    self._labels.append(int(label))
+
+    def __len__(self):
+        return len(self._path_to_videos)
+
+    def __getitem__(self, idx):
+        # If a sample fails to load, retry with a randomly chosen sample,
+        # up to self._num_retries times.
+        for ir in range(self._num_retries):
+            path = self._path_to_videos[idx]
+            num_frames = self._num_frames[idx]
+            try:
+                frames = self.pipeline(
+                    path,
+                    num_frames,
+                    format=self.format,
+                    seg_num=self.seg_num,
+                    seglen=self.seglen,
+                    short_size=self.short_size,
+                    target_size=self.target_size,
+                    img_mean=self.img_mean,
+                    img_std=self.img_std,
+                    mode=self.mode)
+            except Exception:
+                if ir < self._num_retries - 1:
+                    logger.error(
+                        'Error when loading {}, have tried {} times, will try again'.
+                        format(path, ir + 1))
+                    idx = random.randint(0, len(self._path_to_videos) - 1)
+                    continue
+                else:
+                    logger.error(
+                        'Error when loading {}, have tried {} times, will not try again'.
+                        format(path, ir + 1))
+                    return None, None
+            label = self._labels[idx]
+            return frames, np.array([label])  #, np.array([idx])
+
+    def pipeline(self, filepath, num_frames, format, seg_num, seglen,
+                 short_size, target_size, img_mean, img_std, mode):
+        # Loader
+        if format == 'videos':
+            Loader_ops = [
+                VideoDecoder(filepath), VideoSampler(seg_num, seglen, mode)
+            ]
+        elif format == 'frames':
+            Loader_ops = [
+                FrameLoader(filepath, num_frames, seg_num, seglen, mode)
+            ]
+
+        # Augmentation
+        if mode == 'train':
+            Aug_ops = [
+                Scale(short_size), RandomCrop(target_size), RandomFlip(),
+                Image2Array(), Normalization(img_mean, img_std)
+            ]
+        else:
+            Aug_ops = [
+                Scale(short_size), CenterCrop(target_size), Image2Array(),
+                Normalization(img_mean, img_std)
+            ]
+
+        # Chain the ops: the loader produces the raw data, then each
+        # augmentation transforms it in turn.
+        ops = Loader_ops + Aug_ops
+        data = ops[0]()
+        for op in ops[1:]:
+            data = op(data)
+        return data
diff --git a/dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py b/dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
index 0ade23f0..ecf281b5 100644
--- a/dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
+++ b/dygraph/tsn/data/dataset/ucf101/build_ucf101_file_list.py
@@ -103,7 +103,7 @@ def parse_args():
         default='rawframes',
         choices=['rawframes', 'videos'])
     parser.add_argument('--out_list_path', type=str, default='./')
-    parser.add_argument('--shuffle', action='store_true', default=True)
+    parser.add_argument('--shuffle', action='store_true', default=False)
     args = parser.parse_args()
     return args
 
@@ -146,11 +146,12 @@ def main():
         lists = build_split_list(split_tp[i], frame_info,
                                  shuffle=args.shuffle)
         filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
+        PATH = os.path.abspath(args.frame_path)
         with open(os.path.join(out_path, filename), 'w') as f:
-            f.writelines(lists[0])
+            f.writelines([os.path.join(PATH, item) for item in lists[0]])
         filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
         with open(os.path.join(out_path, filename), 'w') as f:
-            f.writelines(lists[1])
+            f.writelines([os.path.join(PATH, item) for item in lists[1]])
 
 
 if __name__ == "__main__":
diff --git a/dygraph/tsn/loader.py b/dygraph/tsn/loader.py
new file mode 100644
index 00000000..3e80c4a6
--- /dev/null
+++ b/dygraph/tsn/loader.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import random
+from PIL import Image
+
+
+class VideoDecoder(object):
+    """
+    Decode an mp4 file into frames.
+    Args:
+        filepath: the file path of the mp4 file.
+    """
+
+    def __init__(self, filepath):
+        self.filepath = filepath
+
+    def __call__(self):
+        """
+        Perform mp4 decode operations.
+        return:
+            List where each item is a numpy array after decoding.
+        """
+        cap = cv2.VideoCapture(self.filepath)
+        videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        sampledFrames = []
+        for i in range(videolen):
+            ret, frame = cap.read()
+            # skip frames that fail to decode (e.g. an empty first frame)
+            if not ret:
+                continue
+            img = frame[:, :, ::-1]  # BGR -> RGB
+            sampledFrames.append(img)
+        return sampledFrames
+
+
+class VideoSampler(object):
+    """
+    Sample frames.
+    Args:
+        num_seg(int): number of segments.
+        seg_len(int): number of sampled frames in each segment.
+        mode(str): 'train', 'test' or 'infer'.
+    """
+
+    def __init__(self, num_seg, seg_len, mode):
+        self.num_seg = num_seg
+        self.seg_len = seg_len
+        self.mode = mode
+
+    def __call__(self, frames):
+        """
+        Args:
+            frames: List where each item is a numpy array decoded from a video.
+        return:
+            List where each item is a PIL.Image after sampling.
+        """
+        average_dur = int(len(frames) / self.num_seg)
+        imgs = []
+        for i in range(self.num_seg):
+            idx = 0
+            if self.mode == 'train':
+                # training: sample a random start offset inside each segment
+                if average_dur >= self.seg_len:
+                    idx = random.randint(0, average_dur - self.seg_len)
+                    idx += i * average_dur
+                elif average_dur >= 1:
+                    idx += i * average_dur
+                else:
+                    idx = i
+            else:
+                # testing: sample from the middle of each segment
+                if average_dur >= self.seg_len:
+                    idx = (average_dur - 1) // 2
+                    idx += i * average_dur
+                elif average_dur >= 1:
+                    idx += i * average_dur
+                else:
+                    idx = i
+
+            for jj in range(idx, idx + self.seg_len):
+                imgbuf = frames[int(jj % len(frames))]
+                img = Image.fromarray(imgbuf, mode='RGB')
+                imgs.append(img)
+        return imgs
+
+
+class FrameLoader(object):
+    """
+    Load frames.
+    Args:
+        filepath(str): the directory containing the extracted frames of a video.
+        num_frames(int): number of frames of the video.
+        num_seg(int): number of segments.
+        seg_len(int): number of sampled frames in each segment.
+        mode(str): 'train', 'test' or 'infer'.
+    """
+
+    def __init__(self, filepath, num_frames, num_seg, seg_len, mode):
+        self.filepath = filepath
+        self.num_frames = num_frames
+        self.num_seg = num_seg
+        self.seg_len = seg_len
+        self.mode = mode
+
+    def __call__(self):
+        """
+        return:
+            imgs: List where each item is a PIL.Image.
+        """
+        average_dur = int(self.num_frames / self.num_seg)
+        imgs = []
+        for i in range(self.num_seg):
+            idx = 0
+            if self.mode == 'train':
+                if average_dur >= self.seg_len:
+                    idx = random.randint(0, average_dur - self.seg_len)
+                    idx += i * average_dur
+                elif average_dur >= 1:
+                    idx += i * average_dur
+                else:
+                    idx = i
+            else:
+                if average_dur >= self.seg_len:
+                    idx = (average_dur - 1) // 2
+                    idx += i * average_dur
+                elif average_dur >= 1:
+                    idx += i * average_dur
+                else:
+                    idx = i
+
+            for jj in range(idx, idx + self.seg_len):
+                img = Image.open(
+                    os.path.join(self.filepath, 'img_{:05d}.jpg'.format(
+                        jj + 1))).convert('RGB')
+                imgs.append(img)
+        return imgs
diff --git a/dygraph/tsn/multi_tsn_frame.yaml b/dygraph/tsn/multi_tsn_frame.yaml
index 9cc76a4d..a4aa69b9 100644
--- a/dygraph/tsn/multi_tsn_frame.yaml
+++ b/dygraph/tsn/multi_tsn_frame.yaml
@@ -13,8 +13,6 @@ TRAIN:
     epoch: 80
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     use_gpu: True
     filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
@@ -24,19 +22,19 @@ TRAIN:
     l2_weight_decay: 1e-4
     momentum: 0.9
     total_videos: 9738
+    num_workers: 4
+    use_shuffle: True
 
 VALID:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
 
 TEST:
     short_size: 256
     target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
     batch_size: 128
     filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
diff --git a/dygraph/tsn/multi_tsn_video.yaml b/dygraph/tsn/multi_tsn_video.yaml
index f664d359..701e5953 100644
--- a/dygraph/tsn/multi_tsn_video.yaml
+++ b/dygraph/tsn/multi_tsn_video.yaml
@@ -13,8 +13,6 @@ TRAIN:
    epoch: 80
    short_size: 256
    target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
    batch_size: 128
    use_gpu: True
    filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
@@ -24,19 +22,19 @@ TRAIN:
    l2_weight_decay: 1e-4
    momentum: 0.9
    total_videos: 9738
+    num_workers: 4
+    use_shuffle: True
 
 VALID:
    short_size: 256
    target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
    batch_size: 128
    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
 
 TEST:
    short_size: 256
    target_size: 224
-    num_reader_threads: 16
-    buf_size: 256
-    batch_size: 128
-    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
\ No newline at end of file
+    batch_size: 128
+    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
diff --git a/dygraph/tsn/single_tsn_frame.yaml b/dygraph/tsn/single_tsn_frame.yaml
index e2b7eff9..28e5774a 100644
--- a/dygraph/tsn/single_tsn_frame.yaml
+++ b/dygraph/tsn/single_tsn_frame.yaml
@@ -13,8 +13,6 @@ TRAIN:
    epoch: 80
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
    use_gpu: True
    filelist: "./data/dataset/ucf101/ucf101_train_split_1_rawframes.txt"
@@ -24,19 +22,19 @@ TRAIN:
    l2_weight_decay: 1e-4
    momentum: 0.9
    total_videos: 9738
+    num_workers: 4
+    use_shuffle: True
 
 VALID:
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
    filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
 
 TEST:
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
-    filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
\ No newline at end of file
+    filelist: "./data/dataset/ucf101/ucf101_val_split_1_rawframes.txt"
+    num_workers: 4
diff --git a/dygraph/tsn/single_tsn_video.yaml b/dygraph/tsn/single_tsn_video.yaml
index c5c032ed..402d3cf4 100644
--- a/dygraph/tsn/single_tsn_video.yaml
+++ b/dygraph/tsn/single_tsn_video.yaml
@@ -13,8 +13,6 @@ TRAIN:
    epoch: 80
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
    use_gpu: True
    filelist: "./data/dataset/ucf101/ucf101_train_split_1_videos.txt"
@@ -24,19 +22,19 @@ TRAIN:
    l2_weight_decay: 1e-4
    momentum: 0.9
    total_videos: 9738
+    num_workers: 4
+    use_shuffle: True
 
 VALID:
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
 
 TEST:
    short_size: 256
    target_size: 224
-    num_reader_threads: 8
-    buf_size: 64
    batch_size: 32
-    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
\ No newline at end of file
+    filelist: "./data/dataset/ucf101/ucf101_val_split_1_videos.txt"
+    num_workers: 4
diff --git a/dygraph/tsn/train.py b/dygraph/tsn/train.py
index c40e63cb..1cfcf27b 100644
--- a/dygraph/tsn/train.py
+++ b/dygraph/tsn/train.py
@@ -27,6 +27,9 @@ from paddle.fluid.dygraph.base import to_variable
 from model import TSN_ResNet
 from utils.config_utils import *
 from reader.ucf101_reader import UCF101Reader
+import paddle
+from paddle.io import DataLoader, DistributedBatchSampler
+from compose import TSN_UCF101_Dataset
 
 logging.root.handlers = []
 FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
@@ -111,19 +114,15 @@ def init_model(model, pre_state_dict):
     return model
 
 
-def val(epoch, model, cfg, args):
-    reader = UCF101Reader(name="TSN", mode="valid", cfg=cfg)
-    reader = reader.create_reader()
+def val(epoch, model, val_loader, cfg, args):
     total_loss = 0.0
     total_acc1 = 0.0
     total_acc5 = 0.0
     total_sample = 0
 
-    for batch_id, data in enumerate(reader()):
-        x_data = np.array([item[0] for item in data])
-        y_data = np.array([item[1] for item in data]).reshape([-1, 1])
-        imgs = to_variable(x_data)
-        labels = to_variable(y_data)
+    for batch_id, data in enumerate(val_loader):
+        imgs = paddle.to_tensor(data[0])
+        labels = paddle.to_tensor(data[1])
         labels.stop_gradient = True
         outputs = model(imgs)
 
@@ -210,11 +209,30 @@ def train(args):
         gpus = gpus.split(",")
         num_gpus = len(gpus)
         bs_denominator = num_gpus
-    train_config.TRAIN.batch_size = int(train_config.TRAIN.batch_size /
-                                        bs_denominator)
-
-    train_reader = UCF101Reader(name="TSN", mode="train", cfg=train_config)
-    train_reader = train_reader.create_reader()
+    bs_train_single = int(train_config.TRAIN.batch_size / bs_denominator)
+    bs_val_single = int(valid_config.VALID.batch_size / bs_denominator)
+
+    train_dataset = TSN_UCF101_Dataset(train_config, 'train')
+    val_dataset = TSN_UCF101_Dataset(valid_config, 'valid')
+    train_sampler = DistributedBatchSampler(
+        train_dataset,
+        batch_size=bs_train_single,
+        shuffle=train_config.TRAIN.use_shuffle,
+        drop_last=True)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_sampler=train_sampler,
+        places=place,
+        num_workers=train_config.TRAIN.num_workers,
+        return_list=True)
+    val_sampler = DistributedBatchSampler(
+        val_dataset, batch_size=bs_val_single)
+    val_loader = DataLoader(
+        val_dataset,
+        batch_sampler=val_sampler,
+        places=place,
+        num_workers=valid_config.VALID.num_workers,
+        return_list=True)
 
     if use_data_parallel:
         # (data_parallel step4/6)
@@ -234,12 +252,10 @@ def train(args):
             total_acc5 = 0.0
             total_sample = 0
             batch_start = time.time()
-            for batch_id, data in enumerate(train_reader()):
+            for batch_id, data in enumerate(train_loader):
                 train_reader_cost = time.time() - batch_start
-                x_data = np.array([item[0] for item in data]).astype("float32")
-                y_data = np.array([item[1] for item in data]).reshape([-1, 1])
-                imgs = to_variable(x_data)
-                labels = to_variable(y_data)
+                imgs = paddle.to_tensor(data[0])
+                labels = paddle.to_tensor(data[1])
                 labels.stop_gradient = True
                 outputs = video_model(imgs)
 
@@ -292,13 +308,13 @@ def train(args):
                 model_path = os.path.join(
                     args.checkpoint,
                     "_" + model_path_pre + "_epoch{}".format(epoch))
-                fluid.dygraph.save_dygraph(
-                    video_model.state_dict(), model_path)
+                fluid.dygraph.save_dygraph(video_model.state_dict(), model_path)
                 fluid.dygraph.save_dygraph(optimizer.state_dict(), model_path)
 
             if args.validate:
                 video_model.eval()
-                val_acc = val(epoch, video_model, valid_config, args)
+                val_acc = val(epoch, video_model, val_loader, valid_config,
+                              args)
                 # save the best parameters in training stage
                 if epoch == 1:
                     best_acc = val_acc
-- 
GitLab
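
Editor's appendix (not part of the patch): a minimal usage sketch of the new
Dataset/DataLoader path introduced above, mirroring the train.py hunks. It
assumes the parse_config/merge_configs helpers from utils/config_utils (which
train.py imports with *) and the YAML fields added by this patch
(TRAIN.batch_size, TRAIN.num_workers, TRAIN.use_shuffle, TRAIN.use_gpu);
adjust names if the repo differs.

    import paddle.fluid as fluid
    from paddle.io import DataLoader, DistributedBatchSampler

    from compose import TSN_UCF101_Dataset
    from utils.config_utils import parse_config, merge_configs  # assumed helpers

    cfg = parse_config('single_tsn_frame.yaml')
    train_config = merge_configs(cfg, 'train', {})

    place = fluid.CUDAPlace(0) if train_config.TRAIN.use_gpu else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        # Each sample runs the per-clip pipeline: FrameLoader (or VideoDecoder
        # + VideoSampler) -> Scale -> RandomCrop/CenterCrop -> RandomFlip
        # (train only) -> Image2Array -> Normalization.
        dataset = TSN_UCF101_Dataset(train_config, 'train')

        # DistributedBatchSampler shards batches across trainers; on a single
        # card it behaves like an ordinary (optionally shuffling) batch sampler.
        sampler = DistributedBatchSampler(
            dataset,
            batch_size=train_config.TRAIN.batch_size,
            shuffle=train_config.TRAIN.use_shuffle,
            drop_last=True)
        loader = DataLoader(
            dataset,
            batch_sampler=sampler,
            places=place,
            num_workers=train_config.TRAIN.num_workers,
            return_list=True)

        for imgs, labels in loader:
            # imgs: (N, seg_num * seglen, 3, target_size, target_size) float32
            # labels: (N, 1) integer class ids
            print(imgs.shape, labels.shape)
            break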