diff --git a/inference.py b/inference.py index 0a44c412192e135d5d5777da53548636e366a78b..3cd54ad074a444467de52b19e6b8716dcfbefa53 100644 --- a/inference.py +++ b/inference.py @@ -32,7 +32,7 @@ from hand_data_iter.datasets import draw_bd_handpose if __name__ == "__main__": parser = argparse.ArgumentParser(description=' Project Hand Pose Inference') - parser.add_argument('--model_path', type=str, default = './weights/ReXNetV1-size-256-wingloss102-0.122.pth', + parser.add_argument('--model_path', type=str, default = './weights/ReXNetV1-size-256-loss-adaptive_wing_loss-model_epoch-190.pth', help = 'model_path') # 模型路径 parser.add_argument('--model', type=str, default = 'ReXNetV1', help = '''model : resnet_34,resnet_50,resnet_101,squeezenet1_0,squeezenet1_1,shufflenetv2,shufflenet,mobilenetv2 @@ -91,7 +91,7 @@ if __name__ == "__main__": elif ops.model == "mobilenetv2": model_ = MobileNetV2(num_classes=ops.num_classes) elif ops.model == "ReXNetV1": - model_ = ReXNetV1( width_mult=0.9, depth_mult=1.0, num_classes=ops.num_classes) + model_ = ReXNetV1( width_mult=1.0, depth_mult=1.0, num_classes=ops.num_classes) use_cuda = torch.cuda.is_available()