未验证 提交 42e487ba 编写于 作者: H houj04 提交者: GitHub

xpu support for fom model. (#647)

上级 ba4d0651
...@@ -51,7 +51,12 @@ parser.add_argument("--best_frame", ...@@ -51,7 +51,12 @@ parser.add_argument("--best_frame",
type=int, type=int,
default=None, default=None,
help="Set frame to start from.") help="Set frame to start from.")
parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
# for device
group = parser.add_mutually_exclusive_group()
group.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
group.add_argument("--xpu", dest="xpu", action="store_true", help="xpu mode.")
parser.add_argument("--ratio", parser.add_argument("--ratio",
dest="ratio", dest="ratio",
type=float, type=float,
...@@ -78,26 +83,35 @@ parser.add_argument("--batch_size", ...@@ -78,26 +83,35 @@ parser.add_argument("--batch_size",
type=int, type=int,
default=1, default=1,
help="Batch size for fom model") help="Batch size for fom model")
parser.add_argument( parser.add_argument("--face_enhancement",
"--face_enhancement", dest="face_enhancement",
dest="face_enhancement", action="store_true",
action="store_true", help="use face enhance for face")
help="use face enhance for face") parser.add_argument("--mobile_net",
parser.add_argument( dest="mobile_net",
"--mobile_net", action="store_true",
dest="mobile_net", help="use mobile_net for fom")
action="store_true",
help="use mobile_net for fom")
parser.set_defaults(relative=False) parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False) parser.set_defaults(adapt_scale=False)
parser.set_defaults(face_enhancement=False) parser.set_defaults(face_enhancement=False)
parser.set_defaults(mobile_net=False) parser.set_defaults(mobile_net=False)
parser.add_argument(
"--slice_size",
dest="slice_size",
type=int,
default=0,
help=
"slice driving video to smaller parts to bypass XPU's 4G byte tensor restriction"
)
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.cpu: if args.cpu:
paddle.set_device('cpu') paddle.set_device('cpu')
if args.xpu:
paddle.set_device('xpu')
predictor = FirstOrderPredictor(output=args.output, predictor = FirstOrderPredictor(output=args.output,
filename=args.filename, filename=args.filename,
weight_path=args.weight_path, weight_path=args.weight_path,
...@@ -112,6 +126,6 @@ if __name__ == "__main__": ...@@ -112,6 +126,6 @@ if __name__ == "__main__":
image_size=args.image_size, image_size=args.image_size,
batch_size=args.batch_size, batch_size=args.batch_size,
face_enhancement=args.face_enhancement, face_enhancement=args.face_enhancement,
mobile_net=args.mobile_net) mobile_net=args.mobile_net,
slice_size=args.slice_size)
predictor.run(args.source_image, args.driving_video) predictor.run(args.source_image, args.driving_video)
...@@ -35,6 +35,7 @@ from .base_predictor import BasePredictor ...@@ -35,6 +35,7 @@ from .base_predictor import BasePredictor
class FirstOrderPredictor(BasePredictor): class FirstOrderPredictor(BasePredictor):
def __init__(self, def __init__(self,
output='output', output='output',
weight_path=None, weight_path=None,
...@@ -50,7 +51,8 @@ class FirstOrderPredictor(BasePredictor): ...@@ -50,7 +51,8 @@ class FirstOrderPredictor(BasePredictor):
image_size=256, image_size=256,
face_enhancement=False, face_enhancement=False,
batch_size=1, batch_size=1,
mobile_net=False): mobile_net=False,
slice_size=0):
if config is not None and isinstance(config, str): if config is not None and isinstance(config, str):
with open(config) as f: with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader) self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
...@@ -92,7 +94,7 @@ class FirstOrderPredictor(BasePredictor): ...@@ -92,7 +94,7 @@ class FirstOrderPredictor(BasePredictor):
if weight_path is None: if weight_path is None:
if mobile_net: if mobile_net:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-mobile.pdparams' vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-mobile.pdparams'
else: else:
if self.image_size == 512: if self.image_size == 512:
vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams' vox_cpk_weight_url = 'https://paddlegan.bj.bcebos.com/applications/first_order_model/vox-cpk-512.pdparams'
...@@ -119,6 +121,7 @@ class FirstOrderPredictor(BasePredictor): ...@@ -119,6 +121,7 @@ class FirstOrderPredictor(BasePredictor):
if face_enhancement: if face_enhancement:
from ppgan.faceutils.face_enhancement import FaceEnhancement from ppgan.faceutils.face_enhancement import FaceEnhancement
self.faceenhancer = FaceEnhancement(batch_size=batch_size) self.faceenhancer = FaceEnhancement(batch_size=batch_size)
self.slice_size = slice_size
def read_img(self, path): def read_img(self, path):
img = imageio.imread(path) img = imageio.imread(path)
...@@ -126,10 +129,11 @@ class FirstOrderPredictor(BasePredictor): ...@@ -126,10 +129,11 @@ class FirstOrderPredictor(BasePredictor):
img = np.expand_dims(img, axis=2) img = np.expand_dims(img, axis=2)
# som images have 4 channels # som images have 4 channels
if img.shape[2] > 3: if img.shape[2] > 3:
img = img[:,:,:3] img = img[:, :, :3]
return img return img
def run(self, source_image, driving_video): def run(self, source_image, driving_video):
def get_prediction(face_image): def get_prediction(face_image):
if self.find_best_frame or self.best_frame is not None: if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func( i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
...@@ -177,7 +181,8 @@ class FirstOrderPredictor(BasePredictor): ...@@ -177,7 +181,8 @@ class FirstOrderPredictor(BasePredictor):
reader.close() reader.close()
driving_video = [ driving_video = [
cv2.resize(frame, (self.image_size, self.image_size)) / 255.0 for frame in driving_video cv2.resize(frame, (self.image_size, self.image_size)) / 255.0
for frame in driving_video
] ]
results = [] results = []
...@@ -187,11 +192,17 @@ class FirstOrderPredictor(BasePredictor): ...@@ -187,11 +192,17 @@ class FirstOrderPredictor(BasePredictor):
# for multi person # for multi person
for rec in bboxes: for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]] face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (self.image_size, self.image_size)) / 255.0 face_image = cv2.resize(face_image,
(self.image_size, self.image_size)) / 255.0
predictions = get_prediction(face_image) predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': [predictions[i] for i in range(predictions.shape[0])]}) results.append({
'rec':
rec,
'predict':
[predictions[i] for i in range(predictions.shape[0])]
})
if len(bboxes) == 1 or not self.multi_person: if len(bboxes) == 1 or not self.multi_person:
break break
out_frame = [] out_frame = []
for i in range(len(driving_video)): for i in range(len(driving_video)):
...@@ -222,9 +233,10 @@ class FirstOrderPredictor(BasePredictor): ...@@ -222,9 +233,10 @@ class FirstOrderPredictor(BasePredictor):
def load_checkpoints(self, config, checkpoint_path): def load_checkpoints(self, config, checkpoint_path):
generator = OcclusionAwareGenerator( generator = OcclusionAwareGenerator(**config['model']['generator']
**config['model']['generator']['generator_cfg'], ['generator_cfg'],
**config['model']['common_params'], inference=True) **config['model']['common_params'],
inference=True)
kp_detector = KPDetector( kp_detector = KPDetector(
**config['model']['generator']['kp_detector_cfg'], **config['model']['generator']['kp_detector_cfg'],
...@@ -252,24 +264,61 @@ class FirstOrderPredictor(BasePredictor): ...@@ -252,24 +264,61 @@ class FirstOrderPredictor(BasePredictor):
source = paddle.to_tensor(source_image[np.newaxis].astype( source = paddle.to_tensor(source_image[np.newaxis].astype(
np.float32)).transpose([0, 3, 1, 2]) np.float32)).transpose([0, 3, 1, 2])
driving = paddle.to_tensor( driving_video_np = np.array(driving_video).astype(np.float32)
np.array(driving_video).astype( driving_n, driving_h, driving_w, driving_c = driving_video_np.shape
np.float32)).transpose([0, 3, 1, 2])
driving_slices = []
if self.slice_size != 0:
batch_count_in_slice = int(
np.floor(
float(self.slice_size) /
(self.batch_size * driving_h * driving_w * driving_c)))
assert batch_count_in_slice > 0, "batch_count_in_slice is 0, use smaller batch_size or bigger slice_size"
frame_count_in_slice = batch_count_in_slice * self.batch_size
for slice_start in range(0, driving_n, frame_count_in_slice):
slice_end = slice_start + min(frame_count_in_slice,
driving_n - slice_start)
current_slice = paddle.to_tensor(
driving_video_np[slice_start:slice_end, ]).transpose(
[0, 3, 1, 2])
driving_slices.append(current_slice)
else:
# whole driving as a single slice
driving = paddle.to_tensor(
np.array(driving_video).astype(np.float32)).transpose(
[0, 3, 1, 2])
frame_count_in_slice = driving_n
driving_slices.append(driving)
kp_source = kp_detector(source) kp_source = kp_detector(source)
kp_driving_initial = kp_detector(driving[0:1]) kp_driving_initial = kp_detector(driving_slices[0][0:1])
kp_source_batch = {} kp_source_batch = {}
kp_source_batch["value"] = paddle.tile(kp_source["value"], repeat_times=[self.batch_size,1,1]) kp_source_batch["value"] = paddle.tile(
kp_source_batch["jacobian"] = paddle.tile(kp_source["jacobian"], repeat_times=[self.batch_size,1,1,1]) kp_source["value"], repeat_times=[self.batch_size, 1, 1])
source = paddle.tile(source, repeat_times=[self.batch_size,1,1,1]) kp_source_batch["jacobian"] = paddle.tile(
kp_source["jacobian"], repeat_times=[self.batch_size, 1, 1, 1])
source = paddle.tile(source,
repeat_times=[self.batch_size, 1, 1, 1])
begin_idx = 0 begin_idx = 0
for frame_idx in tqdm(range(int(np.ceil(float(driving.shape[0]) / self.batch_size)))): for frame_idx in tqdm(
frame_num = min(self.batch_size, driving.shape[0] - begin_idx) range(int(np.ceil(float(driving_n) / self.batch_size)))):
driving_frame = driving[begin_idx: begin_idx+frame_num] frame_num = min(self.batch_size, driving_n - begin_idx)
slice_id = int(frame_idx * self.batch_size /
frame_count_in_slice)
internal_start = frame_idx - slice_id * frame_count_in_slice
internal_end = frame_idx - slice_id * frame_count_in_slice + frame_num
driving_frame = driving_slices[slice_id][
internal_start:internal_end]
kp_driving = kp_detector(driving_frame) kp_driving = kp_detector(driving_frame)
kp_source_img = {} kp_source_img = {}
kp_source_img["value"] = kp_source_batch["value"][0:frame_num] kp_source_img["value"] = kp_source_batch["value"][0:frame_num]
kp_source_img["jacobian"] = kp_source_batch["jacobian"][0:frame_num] kp_source_img["jacobian"] = kp_source_batch["jacobian"][
0:frame_num]
kp_norm = normalize_kp( kp_norm = normalize_kp(
kp_source=kp_source, kp_source=kp_source,
kp_driving=kp_driving, kp_driving=kp_driving,
...@@ -277,10 +326,13 @@ class FirstOrderPredictor(BasePredictor): ...@@ -277,10 +326,13 @@ class FirstOrderPredictor(BasePredictor):
use_relative_movement=relative, use_relative_movement=relative,
use_relative_jacobian=relative, use_relative_jacobian=relative,
adapt_movement_scale=adapt_movement_scale) adapt_movement_scale=adapt_movement_scale)
out = generator(source[0:frame_num], kp_source=kp_source_img, kp_driving=kp_norm) out = generator(source[0:frame_num],
img = np.transpose(out['prediction'].numpy(), [0, 2, 3, 1]) * 255.0 kp_source=kp_source_img,
kp_driving=kp_norm)
img = np.transpose(out['prediction'].numpy(),
[0, 2, 3, 1]) * 255.0
if self.face_enhancement: if self.face_enhancement:
img = self.faceenhancer.enhance_from_batch(img) img = self.faceenhancer.enhance_from_batch(img)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册