From 0e958957c11d7b34f5b67658996022424c28b61f Mon Sep 17 00:00:00 2001 From: haoyuying <18844182690@163.com> Date: Mon, 12 Oct 2020 20:16:13 +0800 Subject: [PATCH] revise model.eval() --- demo/key_point_detection/openpose_body/predict.py | 1 - demo/key_point_detection/openpose_hands/predict.py | 1 - .../image/keypoint_detection/openpose_body_estimation/module.py | 1 + .../image/keypoint_detection/openpose_hands_estimation/module.py | 1 + 4 files changed, 2 insertions(+), 2 deletions(-) diff --git a/demo/key_point_detection/openpose_body/predict.py b/demo/key_point_detection/openpose_body/predict.py index 094e85d7..0834899c 100644 --- a/demo/key_point_detection/openpose_body/predict.py +++ b/demo/key_point_detection/openpose_body/predict.py @@ -5,5 +5,4 @@ if __name__ == "__main__": paddle.disable_static() model = hub.Module(name='openpose_body_estimation') - model.eval() out1, out2 = model.predict("demo.jpg") diff --git a/demo/key_point_detection/openpose_hands/predict.py b/demo/key_point_detection/openpose_hands/predict.py index 8c792dbb..78a37389 100644 --- a/demo/key_point_detection/openpose_hands/predict.py +++ b/demo/key_point_detection/openpose_hands/predict.py @@ -5,5 +5,4 @@ if __name__ == "__main__": paddle.disable_static() model = hub.Module(name='openpose_hands_estimation') - model.eval() all_hand_peaks = model.predict("demo.jpg") diff --git a/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py b/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py index 011506d2..0b1f9f32 100644 --- a/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py +++ b/hub_module/modules/image/keypoint_detection/openpose_body_estimation/module.py @@ -180,6 +180,7 @@ class BodyPoseModel(nn.Layer): return out6_1, out6_2 def predict(self, img_path: str, save_path: str = "result"): + self.eval() orgImg = cv2.imread(img_path) data, imageToTest_padded, pad = self.transform(orgImg) Mconv7_stage6_L1, Mconv7_stage6_L2 = self.forward(paddle.to_tensor(data)) diff --git a/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py b/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py index d3fb57f0..cc983d8b 100644 --- a/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py +++ b/hub_module/modules/image/keypoint_detection/openpose_hands_estimation/module.py @@ -170,6 +170,7 @@ class HandPoseModel(nn.Layer): return np.array(all_peaks) def predict(self, img_path: str, save_path: str = 'result', scale: list = [0.5, 1.0, 1.5, 2.0]): + self.eval() self.body_model = hub.Module(name='openpose_body_estimation') self.body_model.eval() org_img = cv2.imread(img_path) -- GitLab