未验证 提交 f0ae1b64 编写于 作者: Z zhiboniu 提交者: GitHub

pose3d part3: train and metric (#6613)

* pose3d train

* delete extra comments

* only pose3d parallel eval; delete pose3d config

* fix mistake edit
上级 c8f5e3be
......@@ -160,7 +160,7 @@ class Checkpointer(Callback):
def __init__(self, model):
super(Checkpointer, self).__init__(model)
cfg = self.model.cfg
self.best_ap = 0.
self.best_ap = -1000.
self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename)
if hasattr(self.model.model, 'student_model'):
......@@ -187,7 +187,11 @@ class Checkpointer(Callback):
if 'save_best_model' in status and status['save_best_model']:
for metric in self.model._metrics:
map_res = metric.get_results()
if 'bbox' in map_res:
eval_func = "ap"
if 'pose3d' in map_res:
key = 'pose3d'
eval_func = "mpjpe"
elif 'bbox' in map_res:
key = 'bbox'
elif 'keypoint' in map_res:
key = 'keypoint'
......@@ -202,8 +206,8 @@ class Checkpointer(Callback):
self.best_ap = map_res[key][0]
save_name = 'best_model'
weight = self.weight.state_dict()
logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap))
logger.info("Best test {} {} is {:0.3f}.".format(
key, eval_func, abs(self.best_ap)))
if weight:
if self.model.use_ema:
# save model and ema_model
......@@ -288,6 +292,7 @@ class VisualDLWriter(Callback):
self.vdl_mAP_step)
self.vdl_mAP_step += 1
class WandbCallback(Callback):
def __init__(self, model):
super(WandbCallback, self).__init__(model)
......@@ -307,10 +312,8 @@ class WandbCallback(Callback):
self.wandb_params = {}
for k, v in model.cfg.items():
if k.startswith("wandb_"):
self.wandb_params.update({
k.lstrip("wandb_"): v
})
self.wandb_params.update({k.lstrip("wandb_"): v})
self._run = None
if dist.get_world_size() < 2 or dist.get_rank() == 0:
_ = self.run
......@@ -318,28 +321,29 @@ class WandbCallback(Callback):
self.run.define_metric("epoch")
self.run.define_metric("eval/*", step_metric="epoch")
self.best_ap = 0
self.best_ap = -1000.
@property
def run(self):
if self._run is None:
if self.wandb.run is not None:
logger.info("There is an ongoing wandb run which will be used"
"for logging. Please use `wandb.finish()` to end that"
"if the behaviour is not intended")
logger.info(
"There is an ongoing wandb run which will be used"
"for logging. Please use `wandb.finish()` to end that"
"if the behaviour is not intended")
self._run = self.wandb.run
else:
self._run = self.wandb.init(**self.wandb_params)
return self._run
def save_model(self,
optimizer,
save_dir,
save_name,
last_epoch,
ema_model=None,
ap=None,
tags=None):
optimizer,
save_dir,
save_name,
last_epoch,
ema_model=None,
ap=None,
tags=None):
if dist.get_world_size() < 2 or dist.get_rank() == 0:
model_path = os.path.join(save_dir, save_name)
metadata = {}
......@@ -347,8 +351,14 @@ class WandbCallback(Callback):
if ap:
metadata["ap"] = ap
if ema_model is None:
ema_artifact = self.wandb.Artifact(name="ema_model-{}".format(self.run.id), type="model", metadata=metadata)
model_artifact = self.wandb.Artifact(name="model-{}".format(self.run.id), type="model", metadata=metadata)
ema_artifact = self.wandb.Artifact(
name="ema_model-{}".format(self.run.id),
type="model",
metadata=metadata)
model_artifact = self.wandb.Artifact(
name="model-{}".format(self.run.id),
type="model",
metadata=metadata)
ema_artifact.add_file(model_path + ".pdema", name="model_ema")
model_artifact.add_file(model_path + ".pdparams", name="model")
......@@ -356,10 +366,13 @@ class WandbCallback(Callback):
self.run.log_artifact(ema_artifact, aliases=tags)
self.run.log_artfact(model_artifact, aliases=tags)
else:
model_artifact = self.wandb.Artifact(name="model-{}".format(self.run.id), type="model", metadata=metadata)
model_artifact = self.wandb.Artifact(
name="model-{}".format(self.run.id),
type="model",
metadata=metadata)
model_artifact.add_file(model_path + ".pdparams", name="model")
self.run.log_artifact(model_artifact, aliases=tags)
def on_step_end(self, status):
mode = status['mode']
......@@ -368,11 +381,9 @@ class WandbCallback(Callback):
training_status = status['training_staus'].get()
for k, v in training_status.items():
training_status[k] = float(v)
metrics = {
"train/" + k: v for k,v in training_status.items()
}
metrics = {"train/" + k: v for k, v in training_status.items()}
self.run.log(metrics)
def on_epoch_end(self, status):
mode = status['mode']
epoch_id = status['epoch_id']
......@@ -383,7 +394,8 @@ class WandbCallback(Callback):
if (
epoch_id + 1
) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_name = str(epoch_id) if epoch_id != end_epoch - 1 else "model_final"
save_name = str(
epoch_id) if epoch_id != end_epoch - 1 else "model_final"
tags = ["latest", "epoch_{}".format(epoch_id)]
self.save_model(
self.model.optimizer,
......@@ -391,8 +403,7 @@ class WandbCallback(Callback):
save_name,
epoch_id + 1,
self.model.use_ema,
tags=tags
)
tags=tags)
if mode == 'eval':
merged_dict = {}
for metric in self.model._metrics:
......@@ -404,7 +415,9 @@ class WandbCallback(Callback):
if 'save_best_model' in status and status['save_best_model']:
for metric in self.model._metrics:
map_res = metric.get_results()
if 'bbox' in map_res:
if 'pose3d' in map_res:
key = 'pose3d'
elif 'bbox' in map_res:
key = 'bbox'
elif 'keypoint' in map_res:
key = 'keypoint'
......@@ -426,10 +439,9 @@ class WandbCallback(Callback):
save_name,
last_epoch=epoch_id + 1,
ema_model=self.model.use_ema,
ap=self.best_ap,
tags=tags
)
ap=abs(self.best_ap),
tags=tags)
def on_train_end(self, status):
self.run.finish()
......
......@@ -49,6 +49,7 @@ TRT_MIN_SUBGRAPH = {
'CenterNet': 5,
'TOOD': 5,
'YOLOX': 8,
'METRO_Body': 3,
}
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
......
......@@ -38,7 +38,7 @@ from ppdet.optimizer import ModelEMA
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval, Pose3DEval
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
from ppdet.data.source.category import get_categories
......@@ -136,6 +136,9 @@ class Trainer(object):
if self.mode == 'eval':
if cfg.architecture == 'FairMOT':
self.loader = create('EvalMOTReader')(self.dataset, 0)
elif cfg.architecture == "METRO_Body":
reader_name = '{}Reader'.format(self.mode.capitalize())
self.loader = create(reader_name)(self.dataset, cfg.worker_num)
else:
self._eval_batch_sampler = paddle.io.BatchSampler(
self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
......@@ -342,6 +345,13 @@ class Trainer(object):
self.cfg.save_dir,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'Pose3DEval':
save_prediction_only = self.cfg.get('save_prediction_only', False)
self._metrics = [
Pose3DEval(
self.cfg.save_dir,
save_prediction_only=save_prediction_only)
]
elif self.cfg.metric == 'MOTDet':
self._metrics = [JDEDetMetric(), ]
else:
......@@ -450,6 +460,7 @@ class Trainer(object):
self.loader.dataset.set_epoch(epoch_id)
model.train()
iter_tic = time.time()
print("loader len:", len(self.loader))
for step_id, data in enumerate(self.loader):
self.status['data_time'].update(time.time() - iter_tic)
self.status['step_id'] = step_id
......@@ -537,7 +548,7 @@ class Trainer(object):
self._compose_callback.on_epoch_end(self.status)
if validate and is_snapshot:
if validate:
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
......@@ -548,10 +559,14 @@ class Trainer(object):
# If metric is VOC, need to be set collate_batch=False.
if self.cfg.metric == 'VOC':
self.cfg['EvalReader']['collate_batch'] = False
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
if self.cfg.metric == "Pose3DEval":
self._eval_loader = create('EvalReader')(
self._eval_dataset, self.cfg.worker_num)
else:
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
# if validation in training is enabled, metrics should be re-init
# Init_mark makes sure this code will only execute once
if validate and Init_mark == False:
......@@ -575,6 +590,7 @@ class Trainer(object):
tic = time.time()
self._compose_callback.on_epoch_begin(self.status)
self.status['mode'] = 'eval'
self.model.eval()
if self.cfg.get('print_flops', False):
flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
......@@ -617,6 +633,15 @@ class Trainer(object):
self._reset_metrics()
def evaluate(self):
# get distributed model
if self.cfg.get('fleet', False):
self.model = fleet.distributed_model(self.model)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
elif self._nranks > 1:
find_unused_parameters = self.cfg[
'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
self.model = paddle.DataParallel(
self.model, find_unused_parameters=find_unused_parameters)
with paddle.no_grad():
self._eval_with_loader(self.loader)
......@@ -921,9 +946,11 @@ class Trainer(object):
if 'segm' in batch_res else None
keypoint_res = batch_res['keypoint'][start:end] \
if 'keypoint' in batch_res else None
pose3d_res = batch_res['pose3d'][start:end] \
if 'pose3d' in batch_res else None
image = visualize_results(
image, bbox_res, mask_res, segm_res, keypoint_res,
int(im_id), catid2name, draw_threshold)
pose3d_res, int(im_id), catid2name, draw_threshold)
self.status['result_image'] = np.array(image.copy())
if self._compose_callback:
self._compose_callback.on_step_end(self.status)
......
......@@ -17,6 +17,7 @@ from . import keypoint_metrics
from .metrics import *
from .keypoint_metrics import *
from .pose3d_metrics import *
__all__ = metrics.__all__ + keypoint_metrics.__all__
......
......@@ -21,7 +21,7 @@ import sys
import numpy as np
import itertools
from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res, get_keypoint_res
from ppdet.metrics.json_results import get_det_res, get_det_poly_res, get_seg_res, get_solov2_segm_res, get_keypoint_res, get_pose3d_res
from ppdet.metrics.map_utils import draw_pr_curve
from ppdet.utils.logger import setup_logger
......@@ -64,6 +64,10 @@ def get_infer_results(outs, catid, bias=0):
infer_res['keypoint'] = get_keypoint_res(outs, im_id)
outs['bbox_num'] = [len(infer_res['keypoint'])]
if 'pose3d' in outs:
infer_res['pose3d'] = get_pose3d_res(outs, im_id)
outs['bbox_num'] = [len(infer_res['pose3d'])]
return infer_res
......@@ -150,7 +154,7 @@ def cocoapi_eval(jsonfile,
results_flatten = list(itertools.chain(*results_per_category))
headers = ['category', 'AP'] * (num_columns // 2)
results_2d = itertools.zip_longest(
*[results_flatten[i::num_columns] for i in range(num_columns)])
* [results_flatten[i::num_columns] for i in range(num_columns)])
table_data = [headers]
table_data += [result for result in results_2d]
table = AsciiTable(table_data)
......
......@@ -157,3 +157,19 @@ def get_keypoint_res(results, im_id):
ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
anns.append(ann)
return anns
def get_pose3d_res(results, im_id):
anns = []
preds = results['pose3d']
for idx in range(im_id.shape[0]):
image_id = im_id[idx].item()
pose3d = preds[idx]
ann = {
'image_id': image_id,
'category_id': 1, # XXX hard code
'pose3d': pose3d.tolist(),
'score': float(1.)
}
anns.append(ann)
return anns
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import os
import json
from collections import defaultdict, OrderedDict
import numpy as np
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['Pose3DEval']
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def mean_per_joint_position_error(pred, gt, has_3d_joints):
"""
Compute mPJPE
"""
gt = gt[has_3d_joints == 1]
gt = gt[:, :, :3]
pred = pred[has_3d_joints == 1]
with paddle.no_grad():
gt_pelvis = (gt[:, 2, :] + gt[:, 3, :]) / 2
gt = gt - gt_pelvis[:, None, :]
pred_pelvis = (pred[:, 2, :] + pred[:, 3, :]) / 2
pred = pred - pred_pelvis[:, None, :]
error = paddle.sqrt(((pred - gt)**2).sum(axis=-1)).mean(axis=-1).numpy()
return error
def compute_similarity_transform(S1, S2):
"""Computes a similarity transform (sR, t) that takes
a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
i.e. solves the orthogonal Procrutes problem.
"""
transposed = False
if S1.shape[0] != 3 and S1.shape[0] != 2:
S1 = S1.T
S2 = S2.T
transposed = True
assert (S2.shape[1] == S1.shape[1])
# 1. Remove mean.
mu1 = S1.mean(axis=1, keepdims=True)
mu2 = S2.mean(axis=1, keepdims=True)
X1 = S1 - mu1
X2 = S2 - mu2
# 2. Compute variance of X1 used for scale.
var1 = np.sum(X1**2)
# 3. The outer product of X1 and X2.
K = X1.dot(X2.T)
# 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
# singular vectors of K.
U, s, Vh = np.linalg.svd(K)
V = Vh.T
# Construct Z that fixes the orientation of R to get det(R)=1.
Z = np.eye(U.shape[0])
Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
# Construct R.
R = V.dot(Z.dot(U.T))
# 5. Recover scale.
scale = np.trace(R.dot(K)) / var1
# 6. Recover translation.
t = mu2 - scale * (R.dot(mu1))
# 7. Error:
S1_hat = scale * R.dot(S1) + t
if transposed:
S1_hat = S1_hat.T
return S1_hat
def compute_similarity_transform_batch(S1, S2):
"""Batched version of compute_similarity_transform."""
S1_hat = np.zeros_like(S1)
for i in range(S1.shape[0]):
S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
return S1_hat
def reconstruction_error(S1, S2, reduction='mean'):
"""Do Procrustes alignment and compute reconstruction error."""
S1_hat = compute_similarity_transform_batch(S1, S2)
re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
if reduction == 'mean':
re = re.mean()
elif reduction == 'sum':
re = re.sum()
return re
def all_gather(data):
if paddle.distributed.get_world_size() == 1:
return data
vlist = []
paddle.distributed.all_gather(vlist, data)
data = paddle.concat(vlist, 0)
return data
class Pose3DEval(object):
"""refer to
https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
Copyright (c) Microsoft, under the MIT License.
"""
def __init__(self, output_eval, save_prediction_only=False):
super(Pose3DEval, self).__init__()
self.output_eval = output_eval
self.res_file = os.path.join(output_eval, "pose3d_results.json")
self.save_prediction_only = save_prediction_only
self.reset()
def reset(self):
self.PAmPJPE = AverageMeter()
self.mPJPE = AverageMeter()
self.eval_results = {}
def get_human36m_joints(self, input):
J24_TO_J14 = paddle.to_tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18])
J24_TO_J17 = paddle.to_tensor(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19])
return paddle.index_select(input, J24_TO_J14, axis=1)
def update(self, inputs, outputs):
gt_3d_joints = all_gather(inputs['joints_3d'])
has_3d_joints = all_gather(inputs['has_3d_joints'])
pred_3d_joints = all_gather(outputs['pose3d'])
if gt_3d_joints.shape[1] == 24:
gt_3d_joints = self.get_human36m_joints(gt_3d_joints)
if pred_3d_joints.shape[1] == 24:
pred_3d_joints = self.get_human36m_joints(pred_3d_joints)
mPJPE_val = mean_per_joint_position_error(pred_3d_joints, gt_3d_joints,
has_3d_joints).mean()
PAmPJPE_val = reconstruction_error(
pred_3d_joints.numpy(),
gt_3d_joints[:, :, :3].numpy(),
reduction=None).mean()
count = int(np.sum(has_3d_joints.numpy()))
self.PAmPJPE.update(PAmPJPE_val * 1000., count)
self.mPJPE.update(mPJPE_val * 1000., count)
def accumulate(self):
if self.save_prediction_only:
logger.info(f'The pose3d result is saved to {self.res_file} '
'and do not evaluate the model.')
return
self.eval_results['pose3d'] = [-self.mPJPE.avg, -self.PAmPJPE.avg]
def log(self):
if self.save_prediction_only:
return
stats_names = ['mPJPE', 'PAmPJPE']
num_values = len(stats_names)
print(' '.join(['| {}'.format(name) for name in stats_names]) + ' |')
print('|---' * (num_values + 1) + '|')
print(' '.join([
'| {:.3f}'.format(abs(value))
for value in self.eval_results['pose3d']
]) + ' |')
def get_results(self):
return self.eval_results
......@@ -34,6 +34,7 @@ def visualize_results(image,
mask_res,
segm_res,
keypoint_res,
pose3d_res,
im_id,
catid2name,
threshold=0.5):
......@@ -48,6 +49,8 @@ def visualize_results(image,
image = draw_segm(image, im_id, catid2name, segm_res, threshold)
if keypoint_res is not None:
image = draw_pose(image, keypoint_res, threshold)
if pose3d_res is not None:
image = draw_pose3d(image, pose3d_res, threshold)
return image
......@@ -319,3 +322,134 @@ def draw_pose(image,
image = Image.fromarray(canvas.astype('uint8'))
plt.close()
return image
def draw_pose3d(image,
results,
visual_thread=0.6,
save_name='pose3d.jpg',
save_dir='output',
returnimg=False,
ids=None):
try:
import matplotlib.pyplot as plt
import matplotlib
plt.switch_backend('agg')
except Exception as e:
logger.error('Matplotlib not found, please install matplotlib.'
'for example: `pip install matplotlib`.')
raise e
pose3d = np.array(results[0]['pose3d']) * 1000
if pose3d.shape[0] == 24:
joints_connectivity_dict = [
[0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 14, 1],
[3, 14, 1], [14, 15, 1], [15, 16, 1], [16, 12, 1], [6, 7, 0],
[7, 8, 0], [11, 10, 1], [10, 9, 1], [8, 12, 0], [9, 12, 1],
[12, 19, 1], [19, 18, 1], [19, 20, 0], [19, 21, 1], [22, 20, 0],
[23, 21, 1]
]
elif pose3d.shape[0] == 14:
joints_connectivity_dict = [
[0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 12, 0],
[3, 12, 1], [6, 7, 0], [7, 8, 0], [11, 10, 1], [10, 9, 1],
[8, 12, 0], [9, 12, 1], [12, 13, 1]
]
else:
print(
"not defined joints number :{}, cannot visualize because unknown of joint connectivity".
format(pose.shape[0]))
return
def draw3Dpose(pose3d,
ax,
lcolor="#3498db",
rcolor="#e74c3c",
add_labels=False):
# pose3d = orthographic_projection(pose3d, cam)
for i in joints_connectivity_dict:
x, y, z = [
np.array([pose3d[i[0], j], pose3d[i[1], j]]) for j in range(3)
]
ax.plot(-x, -z, -y, lw=2, c=lcolor if i[2] else rcolor)
RADIUS = 1000
center_xy = 2 if pose3d.shape[0] == 14 else 14
x, y, z = pose3d[center_xy, 0], pose3d[center_xy, 1], pose3d[center_xy,
2]
ax.set_xlim3d([-RADIUS + x, RADIUS + x])
ax.set_ylim3d([-RADIUS + y, RADIUS + y])
ax.set_zlim3d([-RADIUS + z, RADIUS + z])
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
def draw2Dpose(pose2d,
ax,
lcolor="#3498db",
rcolor="#e74c3c",
add_labels=False):
for i in joints_connectivity_dict:
if pose2d[i[0], 2] and pose2d[i[1], 2]:
x, y = [
np.array([pose2d[i[0], j], pose2d[i[1], j]])
for j in range(2)
]
ax.plot(x, y, 0, lw=2, c=lcolor if i[2] else rcolor)
def draw_img_pose(pose3d,
pose2d=None,
frame=None,
figsize=(12, 12),
savepath=None):
fig = plt.figure(figsize=figsize, dpi=80)
# fig.clear()
fig.tight_layout()
ax = fig.add_subplot(221)
if frame is not None:
ax.imshow(frame, interpolation='nearest')
if pose2d is not None:
draw2Dpose(pose2d, ax)
ax = fig.add_subplot(222, projection='3d')
ax.view_init(45, 45)
draw3Dpose(pose3d, ax)
ax = fig.add_subplot(223, projection='3d')
ax.view_init(0, 0)
draw3Dpose(pose3d, ax)
ax = fig.add_subplot(224, projection='3d')
ax.view_init(0, 90)
draw3Dpose(pose3d, ax)
if savepath is not None:
plt.savefig(savepath)
plt.close()
else:
return fig
def fig2data(fig):
"""
fig = plt.figure()
image = fig2data(fig)
@brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
@param fig a matplotlib figure
@return a numpy 3D array of RGBA values
"""
# draw the renderer
fig.canvas.draw()
# Get the RGBA buffer from the figure
w, h = fig.canvas.get_width_height()
buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)
buf.shape = (w, h, 4)
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
buf = np.roll(buf, 3, axis=2)
image = Image.frombytes("RGBA", (w, h), buf.tostring())
return image.convert("RGB")
fig = draw_img_pose(pose3d, frame=image)
data = fig2data(fig)
return data
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册