diff --git a/demo/key_point_detection/openpose_body/demo.jpg b/demo/key_point_detection/openpose_body/demo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76f33d26244b3c0ee1392f800d3bfb7bd7720257 Binary files /dev/null and b/demo/key_point_detection/openpose_body/demo.jpg differ diff --git a/demo/key_point_detection/openpose_body/predict.py b/demo/key_point_detection/openpose_body/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..93517892e59cd9a6fed7df7df813e313128cf654 --- /dev/null +++ b/demo/key_point_detection/openpose_body/predict.py @@ -0,0 +1,10 @@ +import paddle +import paddlehub as hub + +if __name__ == "__main__": + + paddle.disable_static() + model = hub.Module(name='openpose_body_estimation') + model.eval() + out1, out2 = model.predict("demo.jpg") + print(out1.shape) diff --git a/demo/key_point_detection/openpose_hands/demo.jpg b/demo/key_point_detection/openpose_hands/demo.jpg new file mode 100644 index 0000000000000000000000000000000000000000..76f33d26244b3c0ee1392f800d3bfb7bd7720257 Binary files /dev/null and b/demo/key_point_detection/openpose_hands/demo.jpg differ diff --git a/demo/key_point_detection/openpose_hands/predict.py b/demo/key_point_detection/openpose_hands/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..8c792dbb9a9cd665c349feffc1cbfb8b0ea4c307 --- /dev/null +++ b/demo/key_point_detection/openpose_hands/predict.py @@ -0,0 +1,9 @@ +import paddle +import paddlehub as hub + +if __name__ == "__main__": + + paddle.disable_static() + model = hub.Module(name='openpose_hands_estimation') + model.eval() + all_hand_peaks = model.predict("demo.jpg") diff --git a/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py b/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py new file mode 100644 index 0000000000000000000000000000000000000000..d830613f48d5a4e32e030b369c62d624b35c6a43 --- /dev/null +++ b/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py @@ -0,0 +1,196 @@ +import os +import copy +from collections import OrderedDict + +import cv2 +import paddle +import paddle.nn as nn +import numpy as np +from paddlehub.module.module import moduleinfo +from paddlehub.process.transforms import ResizeScaling, PadDownRight, Normalize, RemovePadding, GetPeak, Connection, DrawPose, Candidate + + +@moduleinfo(name="openpose_body_estimation", + type="CV/image_editing", + author="paddlepaddle", + author_email="", + summary="Openpose_body_estimation is a body pose estimation model based on Realtime Multi-Person 2D Pose \ + Estimation using Part Affinity Fields.", + version="1.0.0") +class BodyposeModel(nn.Layer): + """BodyposeModel + Args: + load_checkpoint(str): Checkpoint save path, default is None. + visualization (bool): Whether to save the estimation result. Default is True. 
+    """
+    def __init__(self, load_checkpoint: str = None, visualization: bool = True):
+        super(BodyposeModel, self).__init__()
+
+        self.resize_func = ResizeScaling()
+        self.pad_func = PadDownRight()
+        self.norm_func = Normalize(std=[1, 1, 1])
+        self.remove_pad = RemovePadding()
+        self.get_peak = GetPeak()
+        self.get_connection = Connection()
+        self.get_candidate = Candidate()
+        self.draw_pose = DrawPose()
+        self.visualization = visualization
+
+        no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1', \
+                          'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2', \
+                          'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1', \
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+        blocks = {}
+        block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]),
+                              ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]),
+                              ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]),
+                              ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]),
+                              ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]),
+                              ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]),
+                              ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3_CPM', [512, 256, 3, 1, 1]),
+                              ('conv4_4_CPM', [256, 128, 3, 1, 1])])
+
+        block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+                                ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+                                ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])
+
+        block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+                                ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+                                ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
+        blocks['block1_1'] = block1_1
+        blocks['block1_2'] = block1_2
+
+        self.model0 = self.make_layers(block0, no_relu_layers)
+
+        for i in range(2, 7):
+            blocks['block%d_1' % i] = OrderedDict([('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+                                                   ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+                                                   ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])])
+
+            blocks['block%d_2' % i] = OrderedDict([('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+                                                   ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+                                                   ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+                                                   ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])])
+
+        for k in blocks.keys():
+            blocks[k] = self.make_layers(blocks[k], no_relu_layers)
+
+        self.model1_1 = blocks['block1_1']
+        self.model2_1 = blocks['block2_1']
+        self.model3_1 = blocks['block3_1']
+        self.model4_1 = blocks['block4_1']
+        self.model5_1 = blocks['block5_1']
+        self.model6_1 = blocks['block6_1']
+
+        self.model1_2 = blocks['block1_2']
+        self.model2_2 = blocks['block2_2']
+        self.model3_2 = blocks['block3_2']
+        self.model4_2 = blocks['block4_2']
+        self.model5_2 = blocks['block5_2']
+        self.model6_2 = blocks['block6_2']
+
+        if load_checkpoint is not None:
+            model_dict = paddle.load(load_checkpoint)[0]
+            self.set_dict(model_dict)
+            print("load custom checkpoint success")
+
+        else:
+            checkpoint = os.path.join(self.directory, 'body_estimation.pdparams')
+            if not os.path.exists(checkpoint):
+                os.system(
+                    'wget 
https://bj.bcebos.com/paddlehub/model/image/keypoint_detection/body_estimation.pdparams -O ' + + checkpoint) + model_dict = paddle.load(checkpoint)[0] + self.set_dict(model_dict) + print("load pretrained checkpoint success") + + def make_layers(self, block: dict, no_relu_layers: list): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_' + layer_name, nn.ReLU())) + layers = tuple(layers) + return nn.Sequential(*layers) + + def transform(self, orgimg: np.ndarray, scale_search: float = 0.5): + process = self.resize_func(orgimg, scale_search) + imageToTest_padded, pad = self.pad_func(process) + process = self.norm_func(imageToTest_padded) + process = np.ascontiguousarray(np.transpose(process[:, :, :, np.newaxis], (3, 2, 0, 1))).astype("float32") + + return process, imageToTest_padded, pad + + def forward(self, x: paddle.Tensor): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = paddle.concat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = paddle.concat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = paddle.concat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = paddle.concat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = paddle.concat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + + def predict(self, img_path: str, save_path: str = "result"): + orgImg = cv2.imread(img_path) + data, imageToTest_padded, pad = self.transform(orgImg) + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.forward(paddle.to_tensor(data)) + Mconv7_stage6_L1 = Mconv7_stage6_L1.numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.numpy() + + heatmap_avg = self.remove_pad(Mconv7_stage6_L2, imageToTest_padded, orgImg, pad) + paf_avg = self.remove_pad(Mconv7_stage6_L1, imageToTest_padded, orgImg, pad) + + all_peaks = self.get_peak(heatmap_avg) + connection_all, special_k = self.get_connection(all_peaks, paf_avg, orgImg) + candidate, subset = self.get_candidate(all_peaks, connection_all, special_k) + + if self.visualization: + canvas = copy.deepcopy(orgImg) + canvas = self.draw_pose(canvas, candidate, subset) + if not os.path.exists(save_path): + os.mkdir(save_path) + save_path = os.path.join(save_path, img_path.rsplit("/", 1)[-1]) + cv2.imwrite(save_path, canvas) + return candidate, subset + + +if __name__ == "__main__": + import numpy as np + + paddle.disable_static() + model = BodyposeModel() + model.eval() + out1, out2 = model.predict("demo.jpg") + print(out1.shape) diff --git a/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py b/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py new file mode 100644 index 0000000000000000000000000000000000000000..34f00bf3caa4e451809bb6aa84483b700e830787 --- /dev/null +++ b/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py @@ -0,0 +1,187 @@ +import os +import copy +from collections import OrderedDict + +import cv2 +import paddle +import numpy as np 
+import paddle.nn as nn +import paddlehub as hub +from skimage.measure import label +from scipy.ndimage.filters import gaussian_filter +from paddlehub.module.module import moduleinfo +from paddlehub.process.functional import npmax +from paddlehub.process.transforms import HandDetect, ResizeScaling, PadDownRight, RemovePadding, DrawPose, DrawHandPose, Normalize + + +@moduleinfo(name="openpose_hands_estimation", + type="CV/image_editing", + author="paddlepaddle", + author_email="", + summary="Openpose_hands_estimation is a hand pose estimation model based on Hand Keypoint Detection in \ + Single Images using Multiview Bootstrapping.", + version="1.0.0") +class HandposeModel(nn.Layer): + """HandposeModel + Args: + load_checkpoint(str): Checkpoint save path, default is None. + visualization (bool): Whether to save the estimation result. Default is True. + """ + def __init__(self, load_checkpoint: str = None, visualization: bool = True): + super(HandposeModel, self).__init__() + self.visualization = visualization + self.hand_detect = HandDetect() + self.resize_func = ResizeScaling() + self.pad_func = PadDownRight() + self.remove_pad = RemovePadding() + self.draw_pose = DrawPose() + self.draw_hand = DrawHandPose() + self.norm_func = Normalize(std=[1, 1, 1]) + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3', \ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + + block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), ('conv5_3_CPM', [512, 128, 3, 1, 1])]) + + block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]), ('conv6_2_CPM', [512, 22, 1, 1, 0])]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])]) + + for k in blocks.keys(): + blocks[k] = self.make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + if load_checkpoint is not None: + model_dict = paddle.load(load_checkpoint)[0] + self.set_dict(model_dict) + print("load custom checkpoint success") + + else: + checkpoint = os.path.join(self.directory, 'hand_estimation.pdparams') + if not os.path.exists(checkpoint): + os.system( + 'wget https://bj.bcebos.com/paddlehub/model/image/keypoint_detection/hand_estimation.pdparams -O ' + + checkpoint) + model_dict = paddle.load(checkpoint)[0] + self.set_dict(model_dict) + print("load pretrained checkpoint success") + + def make_layers(self, block: dict, no_relu_layers: list): + layers = [] + for layer_name, v 
in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_' + layer_name, nn.ReLU())) + layers = tuple(layers) + return nn.Sequential(*layers) + + def forward(self, x: paddle.Tensor): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = paddle.concat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = paddle.concat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = paddle.concat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = paddle.concat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = paddle.concat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 + + def hand_estimation(self, handimg: np.ndarray, scale_search: list): + heatmap_avg = np.zeros((handimg.shape[0], handimg.shape[1], 22)) + + for scale in scale_search: + process = self.resize_func(handimg, scale) + imageToTest_padded, pad = self.pad_func(process) + process = self.norm_func(imageToTest_padded) + process = np.ascontiguousarray(np.transpose(process[:, :, :, np.newaxis], (3, 2, 0, 1))).astype("float32") + data = self.forward(paddle.to_tensor(process)) + data = data.numpy() + heatmap = self.remove_pad(data, imageToTest_padded, handimg, pad) + heatmap_avg += heatmap / len(scale_search) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > 0.05, dtype=np.uint8) + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = npmax(map_ori) + all_peaks.append([x, y]) + + return np.array(all_peaks) + + def predict(self, img_path: str, save_path: str = 'result', scale: list = [0.5, 1.0, 1.5, 2.0]): + self.body_model = hub.Module(name='openpose_body_estimation') + self.body_model.eval() + org_img = cv2.imread(img_path) + + candidate, subset = self.body_model.predict(img_path) + hands_list = self.hand_detect(candidate, subset, org_img) + + all_hand_peaks = [] + + for x, y, w, is_left in hands_list: + peaks = self.hand_estimation(org_img[y:y + w, x:x + w, :], scale) + peaks[:, 0] = np.where(peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x) + peaks[:, 1] = np.where(peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y) + all_hand_peaks.append(peaks) + + if self.visualization: + canvas = copy.deepcopy(org_img) + canvas = self.draw_pose(canvas, candidate, subset) + canvas = self.draw_hand(canvas, all_hand_peaks) + if not os.path.exists(save_path): + os.mkdir(save_path) + save_path = os.path.join(save_path, img_path.rsplit("/", 1)[-1]) + cv2.imwrite(save_path, canvas) + return all_hand_peaks + + +if __name__ == "__main__": + import numpy as np + + paddle.disable_static() + model = HandposeModel() + model.eval() + out1 = model.predict("detect_hand4.jpg") diff --git a/paddlehub/process/functional.py b/paddlehub/process/functional.py index 
39bf0d19f076a8a413d4122ae6c7d1db5cb95c77..5f1f2f449800af734a717cb5519d631024a99d3e 100644
--- a/paddlehub/process/functional.py
+++ b/paddlehub/process/functional.py
@@ -137,3 +137,12 @@ def gram_matrix(data: paddle.Tensor) -> paddle.Tensor:
     features_t = features.transpose((0, 2, 1))
     gram = features.bmm(features_t) / (ch * h * w)
     return gram
+
+
+def npmax(array: np.ndarray):
+    """Return the (row, column) index of the maximum value of a 2D array."""
+    arrayindex = array.argmax(1)
+    arrayvalue = array.max(1)
+    i = arrayvalue.argmax()
+    j = arrayindex[i]
+    return i, j
diff --git a/paddlehub/process/transforms.py b/paddlehub/process/transforms.py
index f3fe4d8bc199c385218994268674b1cb1f364b58..a097ccd885da9a7f6a1d0c1a33d04e0b0d6b9630 100644
--- a/paddlehub/process/transforms.py
+++ b/paddlehub/process/transforms.py
@@ -13,15 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import math
 import random
+import copy
+from typing import Callable
 from collections import OrderedDict
 
 import cv2
 import numpy as np
-from PIL import Image
-
+import matplotlib
+from PIL import Image, ImageEnhance
+from matplotlib import pyplot as plt
+from matplotlib.figure import Figure
+from scipy.ndimage.filters import gaussian_filter
+from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from paddlehub.process.functional import *
 
+matplotlib.use('Agg')
+
 
 class Compose:
     def __init__(self, transforms, to_rgb=True, stay_rgb=False, is_permute=True):
@@ -763,3 +773,415 @@ class SetType:
     def __call__(self, img: np.ndarray):
         img = img.astype(self.type)
         return img
+
+
+class ResizeScaling:
+    """Resize images by scaling method.
+
+    Args:
+        target(int): Target image size.
+        interp(Callable): Interpolation method.
+    """
+    def __init__(self, target: int = 368, interp: Callable = cv2.INTER_CUBIC):
+        self.target = target
+        self.interp = interp
+
+    def __call__(self, img, scale_search):
+        scale = scale_search * self.target / img.shape[0]
+        resize_img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=self.interp)
+        return resize_img
+
+
+class PadDownRight:
+    """Pad images on the bottom and right edges.
+
+    Args:
+        stride(int): Stride used to compute the pad size for the edges.
+        padValue(int): Fill value for the padded area.
+    """
+    def __init__(self, stride: int = 8, padValue: int = 128):
+        self.stride = stride
+        self.padValue = padValue
+
+    def __call__(self, img: np.ndarray):
+        h, w = img.shape[0:2]
+        pad = 4 * [0]
+        pad[2] = 0 if (h % self.stride == 0) else self.stride - (h % self.stride)  # down
+        pad[3] = 0 if (w % self.stride == 0) else self.stride - (w % self.stride)  # right
+
+        img_padded = img
+        pad_up = np.tile(img_padded[0:1, :, :] * 0 + self.padValue, (pad[0], 1, 1))
+        img_padded = np.concatenate((pad_up, img_padded), axis=0)
+        pad_left = np.tile(img_padded[:, 0:1, :] * 0 + self.padValue, (1, pad[1], 1))
+        img_padded = np.concatenate((pad_left, img_padded), axis=1)
+        pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + self.padValue, (pad[2], 1, 1))
+        img_padded = np.concatenate((img_padded, pad_down), axis=0)
+        pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + self.padValue, (1, pad[3], 1))
+        img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+        return img_padded, pad
+
+
+class RemovePadding:
+    """Remove the padding values.
+
+    Args:
+        stride(int): Upsampling factor applied to the network output before the padding is cropped, default is 8.
+
+    """
+    def __init__(self, stride: int = 8):
+        self.stride = stride
+
+    def __call__(self, data: np.ndarray, imageToTest_padded: np.ndarray, oriImg: np.ndarray, pad: list):
+        heatmap = np.transpose(np.squeeze(data), (1, 2, 0))
+        heatmap = cv2.resize(heatmap, (0, 0), fx=self.stride, fy=self.stride, interpolation=cv2.INTER_CUBIC)
+        heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+        heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+        return heatmap
+
+
+class GetPeak:
+    """
+    Get peak values and coordinates from the input heatmap.
+
+    Args:
+        thresh(float): Threshold value for selecting peak value, default is 0.1.
+    """
+    def __init__(self, thresh=0.1):
+        self.thresh = thresh
+
+    def __call__(self, heatmap: np.ndarray):
+        all_peaks = []
+        peak_counter = 0
+        for part in range(18):
+            map_ori = heatmap[:, :, part]
+            one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+            map_left = np.zeros(one_heatmap.shape)
+            map_left[1:, :] = one_heatmap[:-1, :]
+            map_right = np.zeros(one_heatmap.shape)
+            map_right[:-1, :] = one_heatmap[1:, :]
+            map_up = np.zeros(one_heatmap.shape)
+            map_up[:, 1:] = one_heatmap[:, :-1]
+            map_down = np.zeros(one_heatmap.shape)
+            map_down[:, :-1] = one_heatmap[:, 1:]
+
+            peaks_binary = np.logical_and.reduce(
+                (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down,
+                 one_heatmap > self.thresh))
+            peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]))  # note reverse
+            peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
+            peak_id = range(peak_counter, peak_counter + len(peaks))
+            peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i], ) for i in range(len(peak_id))]
+
+            all_peaks.append(peaks_with_score_and_id)
+            peak_counter += len(peaks)
+
+        return all_peaks
+
+
+class CalculateVector:
+    """
+    Vector decomposition and normalization; refer to Realtime Multi-Person 2D Pose Estimation using Part Affinity
+    Fields for more details.
+
+    Args:
+        thresh(float): Threshold value for selecting candidate vector, default is 0.05.
+    """
+    def __init__(self, thresh: float = 0.05):
+        self.thresh = thresh
+
+    def __call__(self, candA: list, candB: list, nA: int, nB: int, score_mid: np.ndarray, oriImg: np.ndarray):
+        connection_candidate = []
+        for i in range(nA):
+            for j in range(nB):
+                vec = np.subtract(candB[j][:2], candA[i][:2])
+                norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + 1e-5
+                vec = np.divide(vec, norm)
+
+                startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=10), \
+                                    np.linspace(candA[i][1], candB[j][1], num=10)))
+
+                vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+                                  for I in range(len(startend))])
+                vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+                                  for I in range(len(startend))])
+
+                score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+                score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(0.5 * oriImg.shape[0] / norm - 1, 0)
+                criterion1 = len(np.nonzero(score_midpts > self.thresh)[0]) > 0.8 * len(score_midpts)
+                criterion2 = score_with_dist_prior > 0
+                if criterion1 and criterion2:
+                    connection_candidate.append(
+                        [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+        return connection_candidate
+
+
+class Connection:
+    """Get connections between selected estimation points.
+
+    Args:
+        mapIdx(list): Part Affinity Fields map index, default is None.
+        limbSeq(list): Limb sequence, given as pairs of body part indices, default is None.
+
+    """
+    def __init__(self, mapIdx: list = None, limbSeq: list = None):
+        if mapIdx and limbSeq:
+            self.mapIdx = mapIdx
+            self.limbSeq = limbSeq
+        else:
+            self.mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                           [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                           [55, 56], [37, 38], [45, 46]]
+
+            self.limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                            [1, 16], [16, 18], [3, 17], [6, 18]]
+        self.calculate_vector = CalculateVector()
+
+    def __call__(self, all_peaks: list, paf_avg: np.ndarray, orgimg: np.ndarray):
+        connection_all = []
+        special_k = []
+        for k in range(len(self.mapIdx)):
+            score_mid = paf_avg[:, :, [x - 19 for x in self.mapIdx[k]]]
+            candA = all_peaks[self.limbSeq[k][0] - 1]
+            candB = all_peaks[self.limbSeq[k][1] - 1]
+            nA = len(candA)
+            nB = len(candB)
+            if nA != 0 and nB != 0:
+                connection_candidate = self.calculate_vector(candA, candB, nA, nB, score_mid, orgimg)
+                connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+                connection = np.zeros((0, 5))
+                for c in range(len(connection_candidate)):
+                    i, j, s = connection_candidate[c][0:3]
+                    if i not in connection[:, 3] and j not in connection[:, 4]:
+                        connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+                        if len(connection) >= min(nA, nB):
+                            break
+
+                connection_all.append(connection)
+            else:
+                special_k.append(k)
+                connection_all.append([])
+
+        return connection_all, special_k
+
+
+class Candidate:
+    """Select candidate for body pose estimation.
+
+    Args:
+        mapIdx(list): Part Affinity Fields map index, default is None.
+        limbSeq(list): Limb sequence, given as pairs of body part indices, default is None.
+    """
+    def __init__(self, mapIdx: list = None, limbSeq: list = None):
+        if mapIdx and limbSeq:
+            self.mapIdx = mapIdx
+            self.limbSeq = limbSeq
+        else:
+            self.mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+                           [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+                           [55, 56], [37, 38], [45, 46]]
+            self.limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+                            [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+                            [1, 16], [16, 18], [3, 17], [6, 18]]
+
+    def __call__(self, all_peaks: list, connection_all: list, special_k: list):
+        subset = -1 * np.ones((0, 20))
+        candidate = np.array([item for sublist in all_peaks for item in sublist])
+        for k in range(len(self.mapIdx)):
+            if k not in special_k:
+                partAs = connection_all[k][:, 0]
+                partBs = connection_all[k][:, 1]
+                indexA, indexB = np.array(self.limbSeq[k]) - 1
+
+                for i in range(len(connection_all[k])):  # = 1:size(temp,1)
+                    found = 0
+                    subset_idx = [-1, -1]
+                    for j in range(len(subset)):  # 1:size(subset,1):
+                        if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+                            subset_idx[found] = j
+                            found += 1
+
+                    if found == 1:
+                        j = subset_idx[0]
+                        if subset[j][indexB] != partBs[i]:
+                            subset[j][indexB] = partBs[i]
+                            subset[j][-1] += 1
+                            subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+                    elif found == 2:  # if found 2 and disjoint, merge them
+                        j1, j2 = subset_idx
+                        membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+                        if len(np.nonzero(membership == 2)[0]) == 0:  # merge
+                            subset[j1][:-2] += (subset[j2][:-2] + 1)
+                            subset[j1][-2:] += subset[j2][-2:]
+                            subset[j1][-2] += connection_all[k][i][2]
+                            subset = np.delete(subset, j2, 0)
+                        else:  # as like found == 1
+                            subset[j1][indexB] = partBs[i]
+                            subset[j1][-1] += 1
+                            subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+                    # if no partA is found in the subset, create a new subset
+                    elif not found and k < 17:
+                        row = -1 * np.ones(20)
+                        row[indexA] = partAs[i]
+                        row[indexB] = partBs[i]
+                        row[-1] = 2
+                        row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+                        subset = np.vstack([subset, row])
+        # delete rows of subset that have too few parts or a low average score
+        deleteIdx = []
+        for i in range(len(subset)):
+            if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+                deleteIdx.append(i)
+        subset = np.delete(subset, deleteIdx, axis=0)
+        return candidate, subset
+
+
+class DrawPose:
+    """
+    Draw pose estimation results on canvas.
+
+    Args:
+        stickwidth(int): Half-width of the ellipse drawn for each limb, default is 4.
+
+    """
+    def __init__(self, stickwidth: int = 4):
+        self.stickwidth = stickwidth
+
+        self.limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], [10, 11], [2, 12], [12, 13],
+                        [13, 14], [2, 1], [1, 15], [15, 17], [1, 16], [16, 18], [3, 17], [6, 18]]
+
+        self.colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0],
+                       [170, 255, 0], [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
+                       [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255], [255, 0, 255],
+                       [255, 0, 170], [255, 0, 85]]
+
+    def __call__(self, canvas: np.ndarray, candidate: np.ndarray, subset: np.ndarray):
+        for i in range(18):
+            for n in range(len(subset)):
+                index = int(subset[n][i])
+                if index == -1:
+                    continue
+                x, y = candidate[index][0:2]
+                cv2.circle(canvas, (int(x), int(y)), 4, self.colors[i], thickness=-1)
+        for i in range(17):
+            for n in range(len(subset)):
+                index = subset[n][np.array(self.limbSeq[i]) - 1]
+                if -1 in index:
+                    continue
+                cur_canvas = canvas.copy()
+                Y = candidate[index.astype(int), 0]
+                X = candidate[index.astype(int), 1]
+                mX = np.mean(X)
+                mY = np.mean(Y)
+                length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
+                angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+                polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), self.stickwidth), int(angle), 0, 360,
+                                           1)
+                cv2.fillConvexPoly(cur_canvas, polygon, self.colors[i])
+                canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+        return canvas
+
+
+class DrawHandPose:
+    """
+    Draw hand pose estimation results on canvas.
+
+    Args:
+        show_number(bool): Whether to show estimation ids in canvas, default is False.
+
+    """
+    def __init__(self, show_number: bool = False):
+        self.edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+                      [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+        self.show_number = show_number
+
+    def __call__(self, canvas: np.ndarray, all_hand_peaks: list):
+        fig = Figure(figsize=plt.figaspect(canvas))
+
+        fig.subplots_adjust(0, 0, 1, 1)
+        fig.subplots_adjust(bottom=0, top=1, left=0, right=1)
+        bg = FigureCanvas(fig)
+        ax = fig.subplots()
+        ax.axis('off')
+        ax.imshow(canvas)
+
+        width, height = ax.figure.get_size_inches() * ax.figure.get_dpi()
+
+        for peaks in all_hand_peaks:
+            for ie, e in enumerate(self.edges):
+                if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
+                    x1, y1 = peaks[e[0]]
+                    x2, y2 = peaks[e[1]]
+                    ax.plot([x1, x2], [y1, y2],
+                            color=matplotlib.colors.hsv_to_rgb([ie / float(len(self.edges)), 1.0, 1.0]))
+
+            for i, keypoint in enumerate(peaks):
+                x, y = keypoint
+                ax.plot(x, y, 'r.')
+                if self.show_number:
+                    ax.text(x, y, str(i))
+        bg.draw()
+        canvas = np.frombuffer(bg.tostring_rgb(), dtype='uint8').reshape(int(height), int(width), 3)
+        return canvas
+
+
+class HandDetect:
+    """Detect hand pose information from body pose estimation result.
+
+    Args:
+        ratioWristElbow(float): Ratio used to extend the wrist point along the elbow-to-wrist direction when centering the hand window, default is 0.33.
+ """ + def __init__(self, ratioWristElbow: float = 0.33): + self.ratioWristElbow = ratioWristElbow + + def __call__(self, candidate: np.ndarray, subset: np.ndarray, oriImg: np.ndarray): + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + for person in subset.astype(int): + has_left = np.sum(person[[5, 6, 7]] == -1) == 0 + has_right = np.sum(person[[2, 3, 4]] == -1) == 0 + if not (has_left or has_right): + continue + hands = [] + # left hand + if has_left: + left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] + x1, y1 = candidate[left_shoulder_index][:2] + x2, y2 = candidate[left_elbow_index][:2] + x3, y3 = candidate[left_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, True]) + # right hand + if has_right: + right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] + x1, y1 = candidate[right_shoulder_index][:2] + x2, y2 = candidate[right_elbow_index][:2] + x3, y3 = candidate[right_wrist_index][:2] + hands.append([x1, y1, x2, y2, x3, y3, False]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + + x = x3 + self.ratioWristElbow * (x3 - x2) + y = y3 + self.ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2) + distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + + x -= width / 2 + y -= width / 2 + + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + + if width >= 20: + detect_result.append([int(x), int(y), int(width), is_left]) + + return detect_result
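
Quick usage sketch for the two new modules, mirroring the demo scripts added under demo/key_point_detection/ (it assumes demo.jpg sits in the working directory and that the pretrained checkpoints can be fetched on first use):

import paddle
import paddlehub as hub

if __name__ == "__main__":
    paddle.disable_static()

    # Body pose: returns the candidate keypoint array and the per-person subset matrix.
    body_model = hub.Module(name='openpose_body_estimation')
    body_model.eval()
    candidate, subset = body_model.predict("demo.jpg")

    # Hand pose: runs the body model internally, crops each hand window,
    # and returns a list of 21-point peak arrays, one per detected hand.
    hand_model = hub.Module(name='openpose_hands_estimation')
    hand_model.eval()
    all_hand_peaks = hand_model.predict("demo.jpg")

With visualization=True (the default), each predict call also writes an annotated copy of the input image under result/.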
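A minimal sanity-check sketch for the npmax helper added to paddlehub/process/functional.py; note that it returns the (row, column) index of the global maximum, not the maximum value itself (the heat array below is made up for illustration):

import numpy as np
from paddlehub.process.functional import npmax

heat = np.array([[0.1, 0.2, 0.3],
                 [0.9, 0.4, 0.5]])
i, j = npmax(heat)  # i: row of the global max (1), j: its column (0)
assert (i, j) == (1, 0)

This matches how the hands module consumes it: y, x = npmax(map_ori).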