# dataPartitionDialog.py
# Committed by whj_dark ("new").
try:
    from PyQt5.QtGui import *
    from PyQt5.QtCore import *
    from PyQt5.QtWidgets import *
except ImportError:
    from PyQt4.QtGui import *
    from PyQt4.QtCore import *

from libs.utils import newIcon

import time
import datetime
import json
import cv2
import numpy as np


# Short alias for QDialogButtonBox; used below to build the OK/Cancel button
# row and to look up its individual buttons.
BB = QDialogButtonBox

class DataPartitionDialog(QDialog):
    """Application-modal dialog asking for train/validation/test percentages.

    The dialog shows a red reminder message, three integer inputs (each
    restricted to 0-100 and pre-filled with 70/15/15) and an OK/Cancel
    button row. After the dialog closes, ``getStatus()`` tells whether the
    user confirmed it and ``getDataPartition()`` returns the entered values.
    """

    def __init__(self, parent=None):
        """Create the dialog and build (and show) its widgets.

        Args:
            parent: Optional owner object. Only its ``lang`` attribute is
                read (``'ch'`` selects the Chinese reminder text); when it
                is ``None`` or has no ``lang`` the English text is used.
        """
        # The dialog is created without a Qt parent; the reference below is
        # kept only to read the UI language.
        super().__init__()
        # NOTE: the attribute name is misspelled ("parnet") but is kept as-is
        # because code outside this class may already read it.
        self.parnet = parent
        self.title = 'DATA PARTITION'

        # Default split percentages shown when the dialog opens.
        self.train_ratio = 70
        self.val_ratio = 15
        self.test_ratio = 15

        self.initUI()

    def initUI(self):
        """Build the widgets, lay them out, wire the buttons and show the dialog."""
        self.setWindowTitle(self.title)
        self.setWindowModality(Qt.ApplicationModal)

        self.flag_accept = True

        # getattr() keeps the dialog usable when it is created without a
        # parent (the constructor's default) or with one lacking ``lang``.
        if getattr(self.parnet, 'lang', None) == 'ch':
            msg = "导出JSON前请保存所有图像的标注且关闭EXCEL!"
        else:
            msg = "Please save all the annotations and close the EXCEL before exporting JSON!"

        info_msg = QLabel(msg, self)
        info_msg.setWordWrap(True)
        info_msg.setStyleSheet("color: red")
        info_msg.setFont(QFont('Arial', 12))

        train_lbl = QLabel('Train split: ', self)
        train_lbl.setFont(QFont('Arial', 15))
        val_lbl = QLabel('Valid split: ', self)
        val_lbl.setFont(QFont('Arial', 15))
        test_lbl = QLabel('Test split: ', self)
        test_lbl.setFont(QFont('Arial', 15))

        self.train_input = QLineEdit(self)
        self.train_input.setFont(QFont('Arial', 15))
        self.val_input = QLineEdit(self)
        self.val_input.setFont(QFont('Arial', 15))
        self.test_input = QLineEdit(self)
        self.test_input.setFont(QFont('Arial', 15))

        self.train_input.setText(str(self.train_ratio))
        self.val_input.setText(str(self.val_ratio))
        self.test_input.setText(str(self.test_ratio))

        # One validator shared by the three fields; parented to the dialog
        # so its lifetime is tied to the dialog rather than to this method.
        validator = QIntValidator(0, 100, self)
        self.train_input.setValidator(validator)
        self.val_input.setValidator(validator)
        self.test_input.setValidator(validator)

        # Row 0: reminder message spanning both columns.
        # Rows 1-3: label in column 0, input in column 1.
        # Row 4: button box spanning both columns.
        gridlayout = QGridLayout()
        gridlayout.addWidget(info_msg, 0, 0, 1, 2)
        gridlayout.addWidget(train_lbl, 1, 0)
        gridlayout.addWidget(val_lbl, 2, 0)
        gridlayout.addWidget(test_lbl, 3, 0)
        gridlayout.addWidget(self.train_input, 1, 1)
        gridlayout.addWidget(self.val_input, 2, 1)
        gridlayout.addWidget(self.test_input, 3, 1)

        bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self)
        bb.button(BB.Ok).setIcon(newIcon('done'))
        bb.button(BB.Cancel).setIcon(newIcon('undo'))
        bb.accepted.connect(self.validate)
        bb.rejected.connect(self.cancel)
        gridlayout.addWidget(bb, 4, 0, 1, 2)

        self.setLayout(gridlayout)

        self.show()

    def validate(self):
        """Handle the OK button: mark the dialog as confirmed and accept it."""
        self.flag_accept = True
        self.accept()

    def cancel(self):
        """Handle the Cancel button: mark the dialog as dismissed and reject it."""
        self.flag_accept = False
        self.reject()

    def reject(self):
        """Reject the dialog, always recording that it was not confirmed.

        Qt calls ``reject()`` directly when the user presses Esc, bypassing
        ``cancel()``; clearing the flag here keeps ``getStatus()`` correct
        for that path as well.
        """
        self.flag_accept = False
        super().reject()

    def getStatus(self):
        """Return True if the dialog was confirmed with OK, False otherwise."""
        return self.flag_accept

    @staticmethod
    def _read_ratio(line_edit):
        """Return the integer typed into *line_edit*, or 0 if it is not a number.

        The integer validator still allows an empty field, which ``int()``
        cannot parse; such input is treated as 0 instead of raising.
        """
        try:
            return int(line_edit.text().strip())
        except ValueError:
            return 0

    def getDataPartition(self):
        """Read the three inputs and return them as integers.

        The parsed values are also stored on ``train_ratio``, ``val_ratio``
        and ``test_ratio``. A blank or unparsable field is read as 0.

        Returns:
            tuple: ``(train_ratio, val_ratio, test_ratio)``.
        """
        self.train_ratio = self._read_ratio(self.train_input)
        self.val_ratio = self._read_ratio(self.val_input)
        self.test_ratio = self._read_ratio(self.test_input)

        return self.train_ratio, self.val_ratio, self.test_ratio

    def closeEvent(self, event):
        """Treat closing the window as a cancellation."""
        self.flag_accept = False
        self.reject()