plot_curve.py 1.3 KB
Newer Older
Q
qiaolongfei 已提交
1
from IPython import display
Q
qiaolongfei 已提交
2
import os
Q
qiaolongfei 已提交
3 4 5 6 7 8


class PlotCost(object):
    """
    Collect train and test cost curves and redraw them inline in IPython.

    Typical use: call ``append_train_cost`` / ``append_test_cost`` from a
    training event handler, then call ``plot`` to refresh the figure.

    Plotting is disabled when the ``DISABLE_PLOT`` environment variable is
    exactly the string ``"True"``. In that case matplotlib is never
    imported and ``plot`` is a no-op; costs are still recorded.
    """

    def __init__(self):
        # Each series is a pair of parallel lists: (steps, costs).
        self.train_costs = ([], [])
        self.test_costs = ([], [])

        # The environment is read once, at construction time.
        self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
        if not self.__plot_is_disabled__():
            # Imported lazily so disabled/headless runs never need matplotlib.
            import matplotlib.pyplot as plt
            self.plt = plt

    def __plot_is_disabled__(self):
        """Return True when ``DISABLE_PLOT`` was ``"True"`` at construction."""
        return self.__disable_plot__ == "True"

    def plot(self):
        """Redraw the recorded cost curves in the IPython output area.

        Both series are always drawn (keeping train on the first colour and
        test on the second), but only non-empty series get a legend entry,
        and each entry is attached to its own line so labels always match
        the curves they describe. Does nothing when plotting is disabled.
        """
        if self.__plot_is_disabled__():
            return

        has_labels = False
        for (steps, costs), label in ((self.train_costs, 'Train Cost'),
                                      (self.test_costs, 'Test Cost')):
            if steps:
                has_labels = True
            else:
                # Labels starting with "_" are skipped by legend(), so an
                # empty series never steals another curve's legend entry.
                label = '_nolegend_'
            self.plt.plot(steps, costs, label=label)

        if has_labels:
            self.plt.legend(loc='upper left')

        # wait=True defers clearing until new output arrives, reducing flicker.
        display.clear_output(wait=True)
        display.display(self.plt.gcf())
        # Clear the figure so the next call starts from a blank canvas.
        self.plt.gcf().clear()

    def append_train_cost(self, step, cost):
        """Record one training cost value observed at ``step``."""
        self.train_costs[0].append(step)
        self.train_costs[1].append(cost)

    def append_test_cost(self, step, cost):
        """Record one test cost value observed at ``step``."""
        self.test_costs[0].append(step)
        self.test_costs[1].append(cost)

    def reset(self):
        """Discard all recorded train and test costs."""
        self.train_costs = ([], [])
        self.test_costs = ([], [])