diff --git a/deploy/pphuman/config/infer_cfg.yml b/deploy/pphuman/config/infer_cfg.yml index fb93ecc1eb9f62d2a6e8a271d661685a8affb6f6..31b347ec01910cd14263a9d0adda283c76fde0c0 100644 --- a/deploy/pphuman/config/infer_cfg.yml +++ b/deploy/pphuman/config/infer_cfg.yml @@ -25,8 +25,13 @@ ATTR: enable: False VIDEO_ACTION: - model_dir: output_inference/pp-stm + model_dir: output_inference/ppTSM batch_size: 1 + frame_len: 8 + sample_freq: 7 + short_size: 340 + target_size: 320 + basemode: "videobased" enable: False SKELETON_ACTION: diff --git a/deploy/pphuman/datacollector.py b/deploy/pphuman/datacollector.py index d636f0cef85e6c6731988005a1ca3f87dc32c507..22f0cc116aab409ef155f6eb14662e3be36f4ea4 100644 --- a/deploy/pphuman/datacollector.py +++ b/deploy/pphuman/datacollector.py @@ -23,6 +23,7 @@ class Result(object): 'mot': dict(), 'attr': dict(), 'kpt': dict(), + 'video_action': dict(), 'skeleton_action': dict(), 'reid': dict() } diff --git a/deploy/pphuman/pipe_utils.py b/deploy/pphuman/pipe_utils.py index b910a5cbb32378baf642c5b937f7d5675a65fdd3..a8968543c88a79adf5163e6579a123fa8bf8ddc2 100644 --- a/deploy/pphuman/pipe_utils.py +++ b/deploy/pphuman/pipe_utils.py @@ -152,6 +152,7 @@ class PipeTimer(Times): 'mot': Times(), 'attr': Times(), 'kpt': Times(), + 'video_action': Times(), 'skeleton_action': Times(), 'reid': Times() } @@ -197,6 +198,7 @@ class PipeTimer(Times): dic['kpt'] = round(self.module_time['kpt'].value() / max(1, self.img_num), 4) if average else self.module_time['kpt'].value() + dic['video_action'] = self.module_time['video_action'].value() dic['skeleton_action'] = round( self.module_time['skeleton_action'].value() / max(1, self.img_num), 4) if average else self.module_time['skeleton_action'].value() diff --git a/deploy/pphuman/pipeline.py b/deploy/pphuman/pipeline.py index 2eb3afc36fe2822f390e23552d10a900980b5bd9..47928b7e859b5bbbbc0abdd4de60a0c48154a8c5 100644 --- a/deploy/pphuman/pipeline.py +++ b/deploy/pphuman/pipeline.py @@ -36,6 +36,7 @@ from python.infer import Detector, DetectorPicoDet from python.attr_infer import AttrDetector from python.keypoint_infer import KeyPointDetector from python.keypoint_postprocess import translate_to_ori_images +from python.video_action_infer import VideoActionRecognizer from python.action_infer import SkeletonActionRecognizer from python.action_utils import KeyPointBuff, SkeletonActionVisualHelper @@ -75,7 +76,7 @@ class Pipeline(object): draw_center_traj (bool): Whether drawing the trajectory of center, default as False secs_interval (int): The seconds interval to count after tracking, default as 10 do_entrance_counting(bool): Whether counting the numbers of identifiers entering - or getting out from the entrance, default as False,only support single class + or getting out from the entrance, default as False, only support single class counting in MOT. """ @@ -181,7 +182,7 @@ class Pipeline(object): else: raise ValueError( - "Illegal Input, please set one of ['video_file','camera_id','image_file', 'image_dir']" + "Illegal Input, please set one of ['video_file', 'camera_id', 'image_file', 'image_dir']" ) return input @@ -218,6 +219,7 @@ class PipePredictor(object): 1. Tracking 2. Tracking -> Attribute 3. Tracking -> KeyPoint -> SkeletonAction Recognition + 4. 
VideoAction Recognition Args: cfg (dict): config of models in pipeline @@ -240,7 +242,7 @@ class PipePredictor(object): draw_center_traj (bool): Whether drawing the trajectory of center, default as False secs_interval (int): The seconds interval to count after tracking, default as 10 do_entrance_counting(bool): Whether counting the numbers of identifiers entering - or getting out from the entrance, default as False,only support single class + or getting out from the entrance, default as False, only support single class counting in MOT. """ @@ -277,6 +279,7 @@ class PipePredictor(object): 'ID_BASED_CLSACTION', False) else False self.with_mtmct = cfg.get('REID', False)['enable'] if cfg.get( 'REID', False) else False + if self.with_attr: print('Attribute Recognition enabled') if self.with_skeleton_action: @@ -296,6 +299,7 @@ class PipePredictor(object): "idbased": False, "skeletonbased": False } + self.is_video = is_video self.multi_camera = multi_camera self.cfg = cfg @@ -416,6 +420,31 @@ class PipePredictor(object): use_dark=False) self.kpt_buff = KeyPointBuff(skeleton_action_frames) + if self.with_video_action: + video_action_cfg = self.cfg['VIDEO_ACTION'] + + basemode = video_action_cfg['basemode'] + self.modebase[basemode] = True + + video_action_model_dir = video_action_cfg['model_dir'] + video_action_batch_size = video_action_cfg['batch_size'] + short_size = video_action_cfg["short_size"] + target_size = video_action_cfg["target_size"] + + self.video_action_predictor = VideoActionRecognizer( + model_dir=video_action_model_dir, + short_size=short_size, + target_size=target_size, + device=device, + run_mode=run_mode, + batch_size=video_action_batch_size, + trt_min_shape=trt_min_shape, + trt_max_shape=trt_max_shape, + trt_opt_shape=trt_opt_shape, + trt_calib_mode=trt_calib_mode, + cpu_threads=cpu_threads, + enable_mkldnn=enable_mkldnn) + if self.with_mtmct: reid_cfg = self.cfg['REID'] model_dir = reid_cfg['model_dir'] @@ -523,9 +552,12 @@ class PipePredictor(object): entrance = [0, height / 2., width, height / 2.] 
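+        # buffer raw frames across the read loop below for clip-level (video) action recognition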
video_fps = fps + video_action_imgs = [] + while (1): if frame_id % 10 == 0: print('frame id: ', frame_id) + ret, frame = capture.read() if not ret: break @@ -660,10 +692,34 @@ class PipePredictor(object): self.pipeline_res.clear('reid') if self.with_video_action: - #predeal, get what your model need - #predict, model preprocess\run\postprocess - #postdeal, interact with pipeline - pass + # get the params + frame_len = self.cfg["VIDEO_ACTION"]["frame_len"] + sample_freq = self.cfg["VIDEO_ACTION"]["sample_freq"] + + if sample_freq * frame_len > frame_count: # video is too short + sample_freq = int(frame_count / frame_len) + + # filter the warmup frames + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['video_action'].start() + + # collect frames + if frame_id % sample_freq == 0: + video_action_imgs.append(frame) + + # the number of collected frames is enough to predict video action + if len(video_action_imgs) == frame_len: + classes, scores = self.video_action_predictor.predict( + video_action_imgs) + if frame_id > self.warmup_frame: + self.pipe_timer.module_time['video_action'].end() + + video_action_res = {"class": classes[0], "score": scores[0]} + self.pipeline_res.update(video_action_res, 'video_action') + + print("video_action_res:", video_action_res) + + video_action_imgs.clear() # next clip self.collector.append(frame_id, self.pipeline_res) @@ -744,10 +800,21 @@ class PipePredictor(object): returnimg=True) skeleton_action_res = result.get('skeleton_action') - if skeleton_action_res is not None: - image = visualize_action(image, mot_res['boxes'], - self.skeleton_action_visual_helper, - "SkeletonAction") + video_action_res = result.get('video_action') + if skeleton_action_res is not None or video_action_res is not None: + video_action_score = None + action_visual_helper = None + if video_action_res and video_action_res["class"] == 1: + video_action_score = video_action_res["score"] + if skeleton_action_res: + action_visual_helper = self.skeleton_action_visual_helper + image = visualize_action( + image, + mot_res['boxes'], + action_visual_collector=action_visual_helper, + action_text="SkeletonAction", + video_action_score=video_action_score, + video_action_text="Fight") return image @@ -784,6 +851,7 @@ class PipePredictor(object): def main(): cfg = merge_cfg(FLAGS) print_arguments(cfg) + pipeline = Pipeline( cfg, FLAGS.image_file, FLAGS.image_dir, FLAGS.video_file, FLAGS.video_dir, FLAGS.camera_id, FLAGS.device, FLAGS.run_mode, diff --git a/deploy/python/video_action_infer.py b/deploy/python/video_action_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..34f63a94cb23c142c89cbb777d57a0db68c024f6 --- /dev/null +++ b/deploy/python/video_action_infer.py @@ -0,0 +1,297 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
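+
+# VideoActionRecognizer wraps a Paddle Inference predictor for clip-level video
+# action recognition (e.g. the PP-TSM based fight detection used in PP-Human).
+# A minimal usage sketch (model path and video file are illustrative):
+#
+#     recognizer = VideoActionRecognizer(
+#         "output_inference/ppTSM", short_size=340, target_size=320, device="GPU")
+#     classes, scores = recognizer.predict("test_fight_demo.mp4")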
+
+import os
+import yaml
+import glob
+
+import cv2
+import numpy as np
+import math
+import paddle
+import sys
+from collections.abc import Sequence
+import paddle.nn.functional as F
+
+# add deploy path of PaddleDetection to sys.path
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
+sys.path.insert(0, parent_path)
+
+from paddle.inference import Config, create_predictor
+from utils import argsparser, Timer, get_current_memory_mb
+from benchmark_utils import PaddleInferBenchmark
+from infer import Detector, print_arguments
+from video_action_preprocess import VideoDecoder, Sampler, Scale, CenterCrop, Normalization, Image2Array
+
+
+def softmax(x):
+    f_x = np.exp(x) / np.sum(np.exp(x))
+    return f_x
+
+
+class VideoActionRecognizer(object):
+    """
+    Args:
+        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+        device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+        run_mode (str): mode of running (paddle/trt_fp32/trt_fp16)
+        batch_size (int): batch size used in inference
+        trt_min_shape (int): min shape for dynamic shape in trt
+        trt_max_shape (int): max shape for dynamic shape in trt
+        trt_opt_shape (int): opt shape for dynamic shape in trt
+        trt_calib_mode (bool): If the model is produced by TRT offline quantization
+            calibration, trt_calib_mode needs to be set to True
+        cpu_threads (int): number of CPU threads
+        enable_mkldnn (bool): whether to enable MKLDNN
+    """
+
+    def __init__(self,
+                 model_dir,
+                 device='CPU',
+                 run_mode='paddle',
+                 num_seg=8,
+                 seg_len=1,
+                 short_size=256,
+                 target_size=224,
+                 top_k=1,
+                 batch_size=1,
+                 trt_min_shape=1,
+                 trt_max_shape=1280,
+                 trt_opt_shape=640,
+                 trt_calib_mode=False,
+                 cpu_threads=1,
+                 enable_mkldnn=False,
+                 ir_optim=True):
+
+        self.num_seg = num_seg
+        self.seg_len = seg_len
+        self.short_size = short_size
+        self.target_size = target_size
+        self.top_k = top_k
+
+        assert batch_size == 1, "VideoActionRecognizer only supports batch_size=1 now."
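+        # predict() builds and runs exactly one clip per call, so batching is not supported yet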
+ + self.model_dir = model_dir + self.device = device + self.run_mode = run_mode + self.batch_size = batch_size + self.trt_min_shape = trt_min_shape + self.trt_max_shape = trt_max_shape + self.trt_opt_shape = trt_opt_shape + self.trt_calib_mode = trt_calib_mode + self.cpu_threads = cpu_threads + self.enable_mkldnn = enable_mkldnn + self.ir_optim = ir_optim + + self.recognize_times = Timer() + + model_file_path = os.path.join(model_dir, "model.pdmodel") + params_file_path = os.path.join(model_dir, "model.pdiparams") + self.config = Config(model_file_path, params_file_path) + + if device == "GPU" or device == "gpu": + self.config.enable_use_gpu(8000, 0) + else: + self.config.disable_gpu() + if self.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + self.config.set_mkldnn_cache_capacity(10) + self.config.enable_mkldnn() + + self.config.switch_ir_optim(self.ir_optim) # default true + + precision_map = { + 'trt_int8': Config.Precision.Int8, + 'trt_fp32': Config.Precision.Float32, + 'trt_fp16': Config.Precision.Half + } + if run_mode in precision_map.keys(): + self.config.enable_tensorrt_engine( + max_batch_size=self.batch_size, + precision_mode=precision_map[run_mode]) + + self.config.enable_memory_optim() + # use zero copy + self.config.switch_use_feed_fetch_ops(False) + + self.predictor = create_predictor(self.config) + + def preprocess_batch(self, file_list): + batched_inputs = [] + for file in file_list: + inputs = self.preprocess(file) + batched_inputs.append(inputs) + batched_inputs = [ + np.concatenate([item[i] for item in batched_inputs]) + for i in range(len(batched_inputs[0])) + ] + self.input_file = file_list + return batched_inputs + + def get_timer(self): + return self.recognize_times + + def predict(self, input): + ''' + Args: + input (str) or (list): video file path or image data list + Returns: + results (dict): + ''' + + input_names = self.predictor.get_input_names() + input_tensor = self.predictor.get_input_handle(input_names[0]) + + output_names = self.predictor.get_output_names() + output_tensor = self.predictor.get_output_handle(output_names[0]) + + # preprocess + self.recognize_times.preprocess_time_s.start() + if type(input) == str: + inputs = self.preprocess_video(input) + else: + inputs = self.preprocess_frames(input) + self.recognize_times.preprocess_time_s.end() + + inputs = np.expand_dims( + inputs, axis=0).repeat( + self.batch_size, axis=0).copy() + + input_tensor.copy_from_cpu(inputs) + + # model prediction + self.recognize_times.inference_time_s.start() + self.predictor.run() + self.recognize_times.inference_time_s.end() + + output = output_tensor.copy_to_cpu() + + # postprocess + self.recognize_times.postprocess_time_s.start() + classes, scores = self.postprocess(output) + self.recognize_times.postprocess_time_s.end() + + return classes, scores + + def preprocess_frames(self, frame_list): + """ + frame_list: list, frame list + return: list + """ + + results = {} + results['frames_len'] = len(frame_list) + results["imgs"] = frame_list + + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + ops = [ + Scale(self.short_size), CenterCrop(self.target_size), Image2Array(), + Normalization(img_mean, img_std) + ] + for op in ops: + results = op(results) + + res = np.expand_dims(results['imgs'], axis=0).copy() + return [res] + + def preprocess_video(self, input_file): + """ + input_file: str, file path + return: list + """ + assert os.path.isfile(input_file) is not None, "{0} not exists".format( + input_file) + + results = 
{'filename': input_file} + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + ops = [ + VideoDecoder(), Sampler( + self.num_seg, self.seg_len, valid_mode=True), + Scale(self.short_size), CenterCrop(self.target_size), Image2Array(), + Normalization(img_mean, img_std) + ] + for op in ops: + results = op(results) + + res = np.expand_dims(results['imgs'], axis=0).copy() + return [res] + + def postprocess(self, output): + output = output.flatten() # numpy.ndarray + output = softmax(output) + classes = np.argpartition(output, -self.top_k)[-self.top_k:] + classes = classes[np.argsort(-output[classes])] + scores = output[classes] + return classes, scores + + +def main(): + if not FLAGS.run_benchmark: + assert FLAGS.batch_size == 1 + assert FLAGS.use_fp16 is False + else: + assert FLAGS.use_gpu is True + + recognizer = VideoActionRecognizer( + FLAGS.model_dir, + short_size=FLAGS.short_size, + target_size=FLAGS.target_size, + device=FLAGS.device, + run_mode=FLAGS.run_mode, + batch_size=FLAGS.batch_size, + trt_min_shape=FLAGS.trt_min_shape, + trt_max_shape=FLAGS.trt_max_shape, + trt_opt_shape=FLAGS.trt_opt_shape, + trt_calib_mode=FLAGS.trt_calib_mode, + cpu_threads=FLAGS.cpu_threads, + enable_mkldnn=FLAGS.enable_mkldnn, ) + + if not FLAGS.run_benchmark: + classes, scores = recognizer.predict(FLAGS.video_file) + print("Current video file: {}".format(FLAGS.video_file)) + print("\ttop-1 class: {0}".format(classes[0])) + print("\ttop-1 score: {0}".format(scores[0])) + else: + cm, gm, gu = get_current_memory_mb() + mems = {'cpu_rss_mb': cm, 'gpu_rss_mb': gm, 'gpu_util': gu * 100} + + perf_info = recognizer.recognize_times.report() + model_dir = FLAGS.model_dir + mode = FLAGS.run_mode + model_info = { + 'model_name': model_dir.strip('/').split('/')[-1], + 'precision': mode.split('_')[-1] + } + data_info = { + 'batch_size': FLAGS.batch_size, + 'shape': "dynamic_shape", + 'data_num': perf_info['img_num'] + } + recognize_log = PaddleInferBenchmark(recognizer.config, model_info, + data_info, perf_info, mems) + recognize_log('Fight') + + +if __name__ == '__main__': + paddle.enable_static() + parser = argsparser() + FLAGS = parser.parse_args() + print_arguments(FLAGS) + FLAGS.device = FLAGS.device.upper() + assert FLAGS.device in ['CPU', 'GPU', 'XPU' + ], "device should be CPU, GPU or XPU" + + main() diff --git a/deploy/python/video_action_preprocess.py b/deploy/python/video_action_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f9f11f7aee643ebfc070073f18f7e28bebf9dd --- /dev/null +++ b/deploy/python/video_action_preprocess.py @@ -0,0 +1,545 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +from collections.abc import Sequence +from PIL import Image +import paddle + + +class Sampler(object): + """ + Sample frames id. + NOTE: Use PIL to read image here, has diff with CV2 + Args: + num_seg(int): number of segments. 
+ seg_len(int): number of sampled frames in each segment. + valid_mode(bool): True or False. + Returns: + frames_idx: the index of sampled #frames. + """ + + def __init__(self, + num_seg, + seg_len, + frame_interval=None, + valid_mode=True, + dense_sample=False, + linspace_sample=False, + use_pil=True): + self.num_seg = num_seg + self.seg_len = seg_len + self.frame_interval = frame_interval + self.valid_mode = valid_mode + self.dense_sample = dense_sample + self.linspace_sample = linspace_sample + self.use_pil = use_pil + + def _get(self, frames_idx, results): + data_format = results['format'] + + if data_format == "frame": + frame_dir = results['frame_dir'] + imgs = [] + for idx in frames_idx: + img = Image.open( + os.path.join(frame_dir, results['suffix'].format( + idx))).convert('RGB') + imgs.append(img) + + elif data_format == "video": + if results['backend'] == 'cv2': + frames = np.array(results['frames']) + imgs = [] + for idx in frames_idx: + imgbuf = frames[idx] + img = Image.fromarray(imgbuf, mode='RGB') + imgs.append(img) + elif results['backend'] == 'decord': + container = results['frames'] + if self.use_pil: + frames_select = container.get_batch(frames_idx) + # dearray_to_img + np_frames = frames_select.asnumpy() + imgs = [] + for i in range(np_frames.shape[0]): + imgbuf = np_frames[i] + imgs.append(Image.fromarray(imgbuf, mode='RGB')) + else: + if frames_idx.ndim != 1: + frames_idx = np.squeeze(frames_idx) + frame_dict = { + idx: container[idx].asnumpy() + for idx in np.unique(frames_idx) + } + imgs = [frame_dict[idx] for idx in frames_idx] + elif results['backend'] == 'pyav': + imgs = [] + frames = np.array(results['frames']) + for idx in frames_idx: + imgbuf = frames[idx] + imgs.append(imgbuf) + imgs = np.stack(imgs) # thwc + else: + raise NotImplementedError + else: + raise NotImplementedError + results['imgs'] = imgs # all image data + return results + + def _get_train_clips(self, num_frames): + ori_seg_len = self.seg_len * self.frame_interval + avg_interval = (num_frames - ori_seg_len + 1) // self.num_seg + + if avg_interval > 0: + base_offsets = np.arange(self.num_seg) * avg_interval + clip_offsets = base_offsets + np.random.randint( + avg_interval, size=self.num_seg) + elif num_frames > max(self.num_seg, ori_seg_len): + clip_offsets = np.sort( + np.random.randint( + num_frames - ori_seg_len + 1, size=self.num_seg)) + elif avg_interval == 0: + ratio = (num_frames - ori_seg_len + 1.0) / self.num_seg + clip_offsets = np.around(np.arange(self.num_seg) * ratio) + else: + clip_offsets = np.zeros((self.num_seg, ), dtype=np.int) + return clip_offsets + + def _get_test_clips(self, num_frames): + ori_seg_len = self.seg_len * self.frame_interval + avg_interval = (num_frames - ori_seg_len + 1) / float(self.num_seg) + if num_frames > ori_seg_len - 1: + base_offsets = np.arange(self.num_seg) * avg_interval + clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int) + else: + clip_offsets = np.zeros((self.num_seg, ), dtype=np.int) + return clip_offsets + + def __call__(self, results): + """ + Args: + frames_len: length of frames. + return: + sampling id. 
+ """ + frames_len = int(results['frames_len']) # total number of frames + + frames_idx = [] + if self.frame_interval is not None: + assert isinstance(self.frame_interval, int) + if not self.valid_mode: + offsets = self._get_train_clips(frames_len) + else: + offsets = self._get_test_clips(frames_len) + + offsets = offsets[:, None] + np.arange(self.seg_len)[ + None, :] * self.frame_interval + offsets = np.concatenate(offsets) + + offsets = offsets.reshape((-1, self.seg_len)) + offsets = np.mod(offsets, frames_len) + offsets = np.concatenate(offsets) + + if results['format'] == 'video': + frames_idx = offsets + elif results['format'] == 'frame': + frames_idx = list(offsets + 1) + else: + raise NotImplementedError + + return self._get(frames_idx, results) + + print("self.frame_interval:", self.frame_interval) + + if self.linspace_sample: # default if False + if 'start_idx' in results and 'end_idx' in results: + offsets = np.linspace(results['start_idx'], results['end_idx'], + self.num_seg) + else: + offsets = np.linspace(0, frames_len - 1, self.num_seg) + offsets = np.clip(offsets, 0, frames_len - 1).astype(np.int64) + if results['format'] == 'video': + frames_idx = list(offsets) + frames_idx = [x % frames_len for x in frames_idx] + elif results['format'] == 'frame': + frames_idx = list(offsets + 1) + else: + raise NotImplementedError + return self._get(frames_idx, results) + + average_dur = int(frames_len / self.num_seg) + + print("results['format']:", results['format']) + + if self.dense_sample: # For ppTSM, default is False + if not self.valid_mode: # train + sample_pos = max(1, 1 + frames_len - 64) + t_stride = 64 // self.num_seg + start_idx = 0 if sample_pos == 1 else np.random.randint( + 0, sample_pos - 1) + offsets = [(idx * t_stride + start_idx) % frames_len + 1 + for idx in range(self.num_seg)] + frames_idx = offsets + else: + sample_pos = max(1, 1 + frames_len - 64) + t_stride = 64 // self.num_seg + start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) + offsets = [] + for start_idx in start_list.tolist(): + offsets += [(idx * t_stride + start_idx) % frames_len + 1 + for idx in range(self.num_seg)] + frames_idx = offsets + else: + for i in range(self.num_seg): + idx = 0 + if not self.valid_mode: + if average_dur >= self.seg_len: + idx = random.randint(0, average_dur - self.seg_len) + idx += i * average_dur + elif average_dur >= 1: + idx += i * average_dur + else: + idx = i + else: + if average_dur >= self.seg_len: + idx = (average_dur - 1) // 2 + idx += i * average_dur + elif average_dur >= 1: + idx += i * average_dur + else: + idx = i + + for jj in range(idx, idx + self.seg_len): + if results['format'] == 'video': + frames_idx.append(int(jj % frames_len)) + elif results['format'] == 'frame': + frames_idx.append(jj + 1) + + elif results['format'] == 'MRI': + frames_idx.append(jj) + else: + raise NotImplementedError + + return self._get(frames_idx, results) + + +class Scale(object): + """ + Scale images. + Args: + short_size(float | int): Short size of an image will be scaled to the short_size. + fixed_ratio(bool): Set whether to zoom according to a fixed ratio. default: True + do_round(bool): Whether to round up when calculating the zoom ratio. default: False + backend(str): Choose pillow or cv2 as the graphics processing backend. 
default: 'pillow' + """ + + def __init__(self, + short_size, + fixed_ratio=True, + keep_ratio=None, + do_round=False, + backend='pillow'): + self.short_size = short_size + assert (fixed_ratio and not keep_ratio) or ( + not fixed_ratio + ), "fixed_ratio and keep_ratio cannot be true at the same time" + self.fixed_ratio = fixed_ratio + self.keep_ratio = keep_ratio + self.do_round = do_round + + assert backend in [ + 'pillow', 'cv2' + ], "Scale's backend must be pillow or cv2, but get {backend}" + + self.backend = backend + + def __call__(self, results): + """ + Performs resize operations. + Args: + imgs (Sequence[PIL.Image]): List where each item is a PIL.Image. + For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...] + return: + resized_imgs: List where each item is a PIL.Image after scaling. + """ + imgs = results['imgs'] + resized_imgs = [] + for i in range(len(imgs)): + img = imgs[i] + if isinstance(img, np.ndarray): + h, w, _ = img.shape + elif isinstance(img, Image.Image): + w, h = img.size + else: + raise NotImplementedError + + if w <= h: + ow = self.short_size + if self.fixed_ratio: # default is True + oh = int(self.short_size * 4.0 / 3.0) + elif not self.keep_ratio: # no + oh = self.short_size + else: + scale_factor = self.short_size / w + oh = int(h * float(scale_factor) + + 0.5) if self.do_round else int(h * + self.short_size / w) + ow = int(w * float(scale_factor) + + 0.5) if self.do_round else int(w * + self.short_size / h) + else: + oh = self.short_size + if self.fixed_ratio: + ow = int(self.short_size * 4.0 / 3.0) + elif not self.keep_ratio: # no + ow = self.short_size + else: + scale_factor = self.short_size / h + oh = int(h * float(scale_factor) + + 0.5) if self.do_round else int(h * + self.short_size / w) + ow = int(w * float(scale_factor) + + 0.5) if self.do_round else int(w * + self.short_size / h) + + if type(img) == np.ndarray: + img = Image.fromarray(img, mode='RGB') + + if self.backend == 'pillow': + resized_imgs.append(img.resize((ow, oh), Image.BILINEAR)) + elif self.backend == 'cv2' and (self.keep_ratio is not None): + resized_imgs.append( + cv2.resize( + img, (ow, oh), interpolation=cv2.INTER_LINEAR)) + else: + resized_imgs.append( + Image.fromarray( + cv2.resize( + np.asarray(img), (ow, oh), + interpolation=cv2.INTER_LINEAR))) + results['imgs'] = resized_imgs + return results + + +class CenterCrop(object): + """ + Center crop images + Args: + target_size(int): Center crop a square with the target_size from an image. + do_round(bool): Whether to round up the coordinates of the upper left corner of the cropping area. default: True + """ + + def __init__(self, target_size, do_round=True, backend='pillow'): + self.target_size = target_size + self.do_round = do_round + self.backend = backend + + def __call__(self, results): + """ + Performs Center crop operations. + Args: + imgs: List where each item is a PIL.Image. + For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...] + return: + ccrop_imgs: List where each item is a PIL.Image after Center crop. 
+ """ + imgs = results['imgs'] + ccrop_imgs = [] + th, tw = self.target_size, self.target_size + if isinstance(imgs, paddle.Tensor): + h, w = imgs.shape[-2:] + x1 = int(round((w - tw) / 2.0)) if self.do_round else (w - tw) // 2 + y1 = int(round((h - th) / 2.0)) if self.do_round else (h - th) // 2 + ccrop_imgs = imgs[:, :, y1:y1 + th, x1:x1 + tw] + else: + for img in imgs: + if self.backend == 'pillow': + w, h = img.size + elif self.backend == 'cv2': + h, w, _ = img.shape + else: + raise NotImplementedError + assert (w >= self.target_size) and (h >= self.target_size), \ + "image width({}) and height({}) should be larger than crop size".format( + w, h, self.target_size) + x1 = int(round((w - tw) / 2.0)) if self.do_round else ( + w - tw) // 2 + y1 = int(round((h - th) / 2.0)) if self.do_round else ( + h - th) // 2 + if self.backend == 'cv2': + ccrop_imgs.append(img[y1:y1 + th, x1:x1 + tw]) + elif self.backend == 'pillow': + ccrop_imgs.append(img.crop((x1, y1, x1 + tw, y1 + th))) + results['imgs'] = ccrop_imgs + return results + + +class Image2Array(object): + """ + transfer PIL.Image to Numpy array and transpose dimensions from 'dhwc' to 'dchw'. + Args: + transpose: whether to transpose or not, default True, False for slowfast. + """ + + def __init__(self, transpose=True, data_format='tchw'): + assert data_format in [ + 'tchw', 'cthw' + ], "Target format must in ['tchw', 'cthw'], but got {data_format}" + self.transpose = transpose + self.data_format = data_format + + def __call__(self, results): + """ + Performs Image to NumpyArray operations. + Args: + imgs: List where each item is a PIL.Image. + For example, [PIL.Image0, PIL.Image1, PIL.Image2, ...] + return: + np_imgs: Numpy array. + """ + imgs = results['imgs'] + if 'backend' in results and results[ + 'backend'] == 'pyav': # [T,H,W,C] in [0, 1] + if self.transpose: + if self.data_format == 'tchw': + t_imgs = imgs.transpose((0, 3, 1, 2)) # tchw + else: + t_imgs = imgs.transpose((3, 0, 1, 2)) # cthw + results['imgs'] = t_imgs + else: + t_imgs = np.stack(imgs).astype('float32') + if self.transpose: + if self.data_format == 'tchw': + t_imgs = t_imgs.transpose(0, 3, 1, 2) # tchw + else: + t_imgs = t_imgs.transpose(3, 0, 1, 2) # cthw + results['imgs'] = t_imgs + return results + + +class VideoDecoder(object): + """ + Decode mp4 file to frames. + Args: + filepath: the file path of mp4 file + """ + + def __init__(self, + backend='cv2', + mode='train', + sampling_rate=32, + num_seg=8, + num_clips=1, + target_fps=30): + + self.backend = backend + # params below only for TimeSformer + self.mode = mode + self.sampling_rate = sampling_rate + self.num_seg = num_seg + self.num_clips = num_clips + self.target_fps = target_fps + + def __call__(self, results): + """ + Perform mp4 decode operations. + return: + List where each item is a numpy array after decoder. 
+ """ + file_path = results['filename'] + results['format'] = 'video' + results['backend'] = self.backend + + if self.backend == 'cv2': # here + cap = cv2.VideoCapture(file_path) + videolen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + + sampledFrames = [] + for i in range(videolen): + ret, frame = cap.read() + # maybe first frame is empty + if ret == False: + continue + img = frame[:, :, ::-1] + sampledFrames.append(img) + results['frames'] = sampledFrames + results['frames_len'] = len(sampledFrames) + + elif self.backend == 'decord': + container = de.VideoReader(file_path) + frames_len = len(container) + results['frames'] = container + results['frames_len'] = frames_len + else: + raise NotImplementedError + return results + + +class Normalization(object): + """ + Normalization. + Args: + mean(Sequence[float]): mean values of different channels. + std(Sequence[float]): std values of different channels. + tensor_shape(list): size of mean, default [3,1,1]. For slowfast, [1,1,1,3] + """ + + def __init__(self, mean, std, tensor_shape=[3, 1, 1], inplace=False): + if not isinstance(mean, Sequence): + raise TypeError( + 'Mean must be list, tuple or np.ndarray, but got {type(mean)}') + if not isinstance(std, Sequence): + raise TypeError( + 'Std must be list, tuple or np.ndarray, but got {type(std)}') + + self.inplace = inplace + if not inplace: + self.mean = np.array(mean).reshape(tensor_shape).astype(np.float32) + self.std = np.array(std).reshape(tensor_shape).astype(np.float32) + else: + self.mean = np.array(mean, dtype=np.float32) + self.std = np.array(std, dtype=np.float32) + + def __call__(self, results): + """ + Performs normalization operations. + Args: + imgs: Numpy array. + return: + np_imgs: Numpy array after normalization. + """ + + if self.inplace: # default is False + n = len(results['imgs']) + h, w, c = results['imgs'][0].shape + norm_imgs = np.empty((n, h, w, c), dtype=np.float32) + for i, img in enumerate(results['imgs']): + norm_imgs[i] = img + + for img in norm_imgs: # [n,h,w,c] + mean = np.float64(self.mean.reshape(1, -1)) # [1, 3] + stdinv = 1 / np.float64(self.std.reshape(1, -1)) # [1, 3] + cv2.subtract(img, mean, img) + cv2.multiply(img, stdinv, img) + else: + imgs = results['imgs'] + norm_imgs = imgs / 255.0 + norm_imgs -= self.mean + norm_imgs /= self.std + if 'backend' in results and results['backend'] == 'pyav': + norm_imgs = paddle.to_tensor(norm_imgs, dtype=paddle.float32) + results['imgs'] = norm_imgs + return results diff --git a/deploy/python/visualize.py b/deploy/python/visualize.py index c26a6e4673c846a5ad0cb0d2c098edea207a7a60..4bffc4d8040861f71a6311fd8d83f07b82ff7d42 100644 --- a/deploy/python/visualize.py +++ b/deploy/python/visualize.py @@ -365,15 +365,35 @@ def visualize_attr(im, results, boxes=None): return im -def visualize_action(im, mot_boxes, action_visual_collector, action_text=""): +def visualize_action(im, + mot_boxes, + action_visual_collector=None, + action_text="", + video_action_score=None, + video_action_text=""): im = cv2.imread(im) if isinstance(im, str) else im - id_detected = action_visual_collector.get_visualize_ids() + im_h, im_w = im.shape[:2] + text_scale = max(1, im.shape[1] / 1600.) 
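+    # per-track action labels go next to their boxes; the clip-level score goes at the top center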
-    for mot_box in mot_boxes:
-        # mot_box is a format with [mot_id, class, score, xmin, ymin, w, h]
-        if mot_box[0] in id_detected:
-            text_position = (int(mot_box[3] + mot_box[5] * 0.75),
-                             int(mot_box[4] - 10))
-            cv2.putText(im, action_text, text_position, cv2.FONT_HERSHEY_PLAIN,
-                        text_scale, (0, 0, 255), 2)
+    text_thickness = 2
+
+    if action_visual_collector:
+        id_detected = action_visual_collector.get_visualize_ids()
+        for mot_box in mot_boxes:
+            # mot_box format: [mot_id, class, score, xmin, ymin, w, h]
+            if mot_box[0] in id_detected:
+                text_position = (int(mot_box[3] + mot_box[5] * 0.75),
+                                 int(mot_box[4] - 10))
+                cv2.putText(im, action_text, text_position,
+                            cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
+                            text_thickness)
+
+    if video_action_score:
+        cv2.putText(
+            im,
+            video_action_text + ': %.2f' % video_action_score,
+            (int(im_w / 2), int(15 * text_scale) + 5),
+            cv2.FONT_ITALIC,
+            text_scale, (0, 0, 255),
+            thickness=text_thickness)
+    return im
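A minimal sketch of how the extended visualize_action signature can be called on
its own (the frame, boxes and score values are illustrative, and the import path
assumes the deploy/python directory is on sys.path, as pipeline.py arranges):

    import cv2
    from visualize import visualize_action

    frame = cv2.imread("demo_frame.jpg")           # any BGR frame
    mot_boxes = [[1, 0, 0.9, 100, 80, 60, 160]]    # [mot_id, class, score, xmin, ymin, w, h]
    frame = visualize_action(
        frame,
        mot_boxes,
        action_visual_collector=None,   # no per-ID skeleton results in this sketch
        video_action_score=0.92,        # illustrative clip-level score
        video_action_text="Fight")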