Unverified commit 42e487ba — authored by: H houj04, committed by: GitHub

xpu support for fom model. (#647)

Parent commit: ba4d0651
......@@ -51,7 +51,12 @@ parser.add_argument("--best_frame",
type=int,
default=None,
help="Set frame to start from.")
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
# for device
group = parser.add_mutually_exclusive_group()
group.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
group.add_argument("--xpu", dest="xpu", action="store_true", help="xpu mode.")
parser.add_argument("--ratio",
dest="ratio",
type=float,
......@@ -78,26 +83,35 @@ parser.add_argument("--batch_size",
type=int,
default=1,
help="Batch size for fom model")
parser.add_argument(
"--face_enhancement",
dest="face_enhancement",
action="store_true",
help="use face enhance for face")
parser.add_argument(
"--mobile_net",
dest="mobile_net",
action="store_true",
help="use mobile_net for fom")
parser.add_argument("--face_enhancement",
dest="face_enhancement",
action="store_true",
help="use face enhance for face")
parser.add_argument("--mobile_net",
dest="mobile_net",
action="store_true",
help="use mobile_net for fom")
parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
parser.set_defaults(face_enhancement=False)
parser.set_defaults(mobile_net=False)
parser.add_argument(
"--slice_size",
dest="slice_size",
type=int,
default=0,
help=
"slice driving video to smaller parts to bypass XPU's 4G byte tensor restriction"
)
if __name__ == "__main__":
args = parser.parse_args()
if args.cpu:
paddle.set_device('cpu')
if args.xpu:
paddle.set_device('xpu')
predictor = FirstOrderPredictor(output=args.output,
filename=args.filename,
weight_path=args.weight_path,
......@@ -112,6 +126,6 @@ if __name__ == "__main__":
image_size=args.image_size,
batch_size=args.batch_size,
face_enhancement=args.face_enhancement,
mobile_net=args.mobile_net)
mobile_net=args.mobile_net,
slice_size=args.slice_size)
predictor.run(args.source_image, args.driving_video)
......@@ -35,6 +35,7 @@ from .base_predictor import BasePredictor
class FirstOrderPredictor(BasePredictor):
def __init__(self,
output='output',
weight_path=None,
......@@ -50,7 +51,8 @@ class FirstOrderPredictor(BasePredictor):
image_size=256,
face_enhancement=False,
batch_size=1,
mobile_net=False):
mobile_net=False,
slice_size=0):
if config is not None and isinstance(config, str):
with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
......@@ -92,7 +94,7 @@ class FirstOrderPredictor(BasePredictor):
if weight_path is None:
if mobile_net:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-mobile.pdparams'
else:
if self.image_size == 512:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
......@@ -119,6 +121,7 @@ class FirstOrderPredictor(BasePredictor):
if face_enhancement:
from ppgan.faceutils.face_enhancement import FaceEnhancement
self.faceenhancer = FaceEnhancement(batch_size=batch_size)
self.slice_size = slice_size
def read_img(self, path):
img = imageio.imread(path)
......@@ -126,10 +129,11 @@ class FirstOrderPredictor(BasePredictor):
img = np.expand_dims(img, axis=2)
        # some images have 4 channels
if img.shape[2] > 3:
img = img[:,:,:3]
img = img[:, :, :3]
return img
def run(self, source_image, driving_video):
def get_prediction(face_image):
if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
......@@ -177,7 +181,8 @@ class FirstOrderPredictor(BasePredictor):
reader.close()
driving_video = [
cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video
cv2.resize(frame, (self.image_size, self.image_size)) / 255.0
for frame in driving_video
]
results = []
......@@ -187,11 +192,17 @@ class FirstOrderPredictor(BasePredictor):
# for multi person
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0
face_image = cv2.resize(face_image,
(self.image_size, self.image_size)) / 255.0
predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': [predictions[i] for i in range(predictions.shape[0])]})
results.append({
'rec':
rec,
'predict':
[predictions[i] for i in range(predictions.shape[0])]
})
if len(bboxes) == 1 or not self.multi_person:
break
break
out_frame = []
for i in range(len(driving_video)):
......@@ -222,9 +233,10 @@ class FirstOrderPredictor(BasePredictor):
def load_checkpoints(self, config, checkpoint_path):
generator = OcclusionAwareGenerator(
**config['model']['generator']['generator_cfg'],
**config['model']['common_params'], inference=True)
generator = OcclusionAwareGenerator(**config['model']['generator']
['generator_cfg'],
**config['model']['common_params'],
inference=True)
kp_detector = KPDetector(
**config['model']['generator']['kp_detector_cfg'],
......@@ -252,24 +264,61 @@ class FirstOrderPredictor(BasePredictor):
source = paddle.to_tensor(source_image[np.newaxis].astype(
np.float32)).transpose([0, 3, 1, 2])
driving = paddle.to_tensor(
np.array(driving_video).astype(
np.float32)).transpose([0, 3, 1, 2])
driving_video_np = np.array(driving_video).astype(np.float32)
driving_n, driving_h, driving_w, driving_c = driving_video_np.shape
driving_slices = []
if self.slice_size != 0:
batch_count_in_slice = int(
np.floor(
float(self.slice_size) /
(self.batch_size * driving_h * driving_w * driving_c)))
assert batch_count_in_slice > 0, "batch_count_in_slice is 0, use smaller batch_size or bigger slice_size"
frame_count_in_slice = batch_count_in_slice * self.batch_size
for slice_start in range(0, driving_n, frame_count_in_slice):
slice_end = slice_start + min(frame_count_in_slice,
driving_n - slice_start)
current_slice = paddle.to_tensor(
driving_video_np[slice_start:slice_end, ]).transpose(
[0, 3, 1, 2])
driving_slices.append(current_slice)
else:
# whole driving as a single slice
driving = paddle.to_tensor(
np.array(driving_video).astype(np.float32)).transpose(
[0, 3, 1, 2])
frame_count_in_slice = driving_n
driving_slices.append(driving)
kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[0:1])
kp_driving_initial = kp_detector(driving_slices[0][0:1])
kp_source_batch = {}
kp_source_batch["value"] = paddle.tile(kp_source["value"], repeat_times=[self.batch_size,1,1])
kp_source_batch["jacobian"] = paddle.tile(kp_source["jacobian"], repeat_times=[self.batch_size,1,1,1])
source = paddle.tile(source, repeat_times=[self.batch_size,1,1,1])
kp_source_batch["value"] = paddle.tile(
kp_source["value"], repeat_times=[self.batch_size, 1, 1])
kp_source_batch["jacobian"] = paddle.tile(
kp_source["jacobian"], repeat_times=[self.batch_size, 1, 1, 1])
source = paddle.tile(source,
repeat_times=[self.batch_size, 1, 1, 1])
begin_idx = 0
for frame_idx in tqdm(range(int(np.ceil(float(driving.shape[0]) / self.batch_size)))):
frame_num = min(self.batch_size, driving.shape[0] - begin_idx)
driving_frame = driving[begin_idx: begin_idx+frame_num]
for frame_idx in tqdm(
range(int(np.ceil(float(driving_n) / self.batch_size)))):
frame_num = min(self.batch_size, driving_n - begin_idx)
slice_id = int(frame_idx * self.batch_size /
frame_count_in_slice)
internal_start = frame_idx - slice_id * frame_count_in_slice
internal_end = frame_idx - slice_id * frame_count_in_slice + frame_num
driving_frame = driving_slices[slice_id][
internal_start:internal_end]
kp_driving = kp_detector(driving_frame)
kp_source_img = {}
kp_source_img["value"] = kp_source_batch["value"][0:frame_num]
kp_source_img["jacobian"] = kp_source_batch["jacobian"][0:frame_num]
kp_source_img["jacobian"] = kp_source_batch["jacobian"][
0:frame_num]
kp_norm = normalize_kp(
kp_source=kp_source,
kp_driving=kp_driving,
......@@ -277,10 +326,13 @@ class FirstOrderPredictor(BasePredictor):
use_relative_movement=relative,
use_relative_jacobian=relative,
adapt_movement_scale=adapt_movement_scale)
out = generator(source[0:frame_num], kp_source=kp_source_img, kp_driving=kp_norm)
img = np.transpose(out['prediction'].numpy(), [0, 2, 3, 1]) * 255.0
out = generator(source[0:frame_num],
kp_source=kp_source_img,
kp_driving=kp_norm)
img = np.transpose(out['prediction'].numpy(),
[0, 2, 3, 1]) * 255.0
if self.face_enhancement:
img = self.faceenhancer.enhance_from_batch(img)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.