utils.py 3.3 KB
Newer Older
1 2 3 4
# Module: Utility
# Author: Moez Ali <moez.ali@queensu.ca>
# License: MIT

P
PyCaret 已提交
5
# Version identifier of this build; printed by version() and returned by __version__().
version_ = "pycaret-nightly-0.37"
P
PyCaret 已提交
6

7
def version():
    """
    Print the module-level version string (``version_``) to standard output.

    Returns
    -------
    None
        The version is printed, not returned. Use ``__version__()`` to get
        the string as a value.
    """
    print(version_)
9

P
PyCaret 已提交
10 11
def __version__():
    """
    Return the module-level version string (``version_``) without printing it.
    """
    return version_
12 13 14 15 16 17 18 19 20 21 22 23

def check_metric(actual, prediction, metric, round=4):

    """
    Compute a single evaluation metric between ground truth and predictions.

    Parameters
    ----------
    actual : array-like
        Ground truth values.

    prediction : array-like
        Predicted values, in the same order as ``actual``.

    metric : str
        Name of the metric to compute. Supported values:
        'Accuracy', 'Recall', 'Precision', 'F1', 'Kappa', 'AUC', 'MCC'
        (delegated to ``sklearn.metrics``), 'MAE', 'MSE', 'RMSE', 'R2'
        (delegated to ``sklearn.metrics``), and 'RMSLE', 'MAPE'
        (computed with numpy).

    round : int, default = 4
        Number of decimal places the result is rounded to.

    Returns
    -------
    numpy.float64
        The rounded metric value.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.

    Notes
    -----
    - RMSLE is computed on the absolute values of both inputs.
    - MAPE ignores observations whose actual value is 0 and divides by the
      absolute actual value, so negative targets contribute a positive error.
      Inputs are compared positionally (they are converted to numpy arrays).
    """

    #general dependencies
    import numpy as np

    #metrics that map one-to-one onto a function of sklearn.metrics
    sklearn_scorers = {
        'Accuracy': 'accuracy_score',
        'Recall': 'recall_score',
        'Precision': 'precision_score',
        'F1': 'f1_score',
        'Kappa': 'cohen_kappa_score',
        'AUC': 'roc_auc_score',
        'MCC': 'matthews_corrcoef',
        'MAE': 'mean_absolute_error',
        'MSE': 'mean_squared_error',
        'RMSE': 'mean_squared_error',
        'R2': 'r2_score',
    }

    #metric calculation starts here

    if metric in sklearn_scorers:

        #imported lazily so that the numpy-only metrics do not require sklearn
        from sklearn import metrics
        scorer = getattr(metrics, sklearn_scorers[metric])
        result = scorer(actual, prediction)

        if metric == 'RMSE':
            result = np.sqrt(result)

    elif metric == 'RMSLE':

        #np.abs on an ndarray (instead of builtin abs) also accepts plain lists
        log_prediction = np.log(np.abs(np.asarray(prediction)) + 1)
        log_actual = np.log(np.abs(np.asarray(actual)) + 1)
        result = np.sqrt(np.mean(np.power(log_prediction - log_actual, 2)))

    elif metric == 'MAPE':

        actual_values = np.asarray(actual, dtype=float)
        predicted_values = np.asarray(prediction, dtype=float)

        #drop zero targets before dividing to avoid division-by-zero
        mask = actual_values != 0
        absolute_error = np.fabs(actual_values[mask] - predicted_values[mask])
        result = (absolute_error / np.fabs(actual_values[mask])).mean()

    else:

        raise ValueError(
            "Unknown metric '{}'. Supported metrics: {}.".format(
                metric, ', '.join(list(sklearn_scorers) + ['RMSLE', 'MAPE'])
            )
        )

    #np.round works for numpy scalars and plain Python floats alike; the
    #builtin round is shadowed by the `round` parameter.
    return np.round(result, round)


def enable_colab():

    """
    Function to render plotly visuals in colab.

    Registers a ``pre_run_cell`` callback on the active IPython shell that
    injects a requirejs configuration pointing at the plotly CDN before each
    cell runs, then prints a confirmation message.

    Must be called from within a running IPython session.
    """

    import IPython
    #public import path; avoids relying on `display` being injected into builtins
    from IPython.display import HTML, display

    def configure_plotly_browser_state(*args, **kwargs):

        #accepts (and ignores) any arguments: newer IPython versions pass an
        #`info` object to pre_run_cell callbacks, older ones pass nothing.
        display(HTML('''
            <script src="/static/components/requirejs/require.js"></script>
            <script>
              requirejs.config({
                paths: {
                  base: '/static/base',
                  plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
                },
              });
            </script>
            '''))

    IPython.get_ipython().events.register('pre_run_cell', configure_plotly_browser_state)
    print('Colab mode activated.')