diff --git a/ding/utils/plot_helper.py b/ding/utils/plot_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..42a3bc0e4114efd9554538bfc2c84f5d99a9d78d --- /dev/null +++ b/ding/utils/plot_helper.py @@ -0,0 +1,27 @@ +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt +import seaborn as sns + + +def plot(data: list, xlabel: str, ylabel: str, title: str, pth: str = './picture.jpg'): + """ + Overview: + Draw training polyline + Interface: + data (:obj:`List[Dict]`): the data we will use to draw polylines + data[i]['step']: horizontal axis data + data[i]['value']: vertical axis data + data[i]['label']: the data label + xlabel (:obj:`str`): the x label name + ylabel (:obj:`str`): the y label name + title (:obj:`str`): the title name + """ + sns.set(style="darkgrid", font_scale=1.5) + for nowdata in data: + step, value, label = nowdata['x'], nowdata['y'], nowdata['label'] + sns.lineplot(x=step, y=value, label=label) + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.title(title) + plt.savefig(pth) diff --git a/ding/utils/tests/test_plot.py b/ding/utils/tests/test_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..99d4a00c0b68af4f110e4b838c3be4c34d9ce8f8 --- /dev/null +++ b/ding/utils/tests/test_plot.py @@ -0,0 +1,35 @@ +import random +import numpy as np +import os +import pytest + +from ding.utils.plot_helper import plot + + +@pytest.mark.unittest +def test_plot(): + rewards1 = np.array([0, 0.1, 0, 0.2, 0.4, 0.5, 0.6, 0.9, 0.9, 0.9]) + rewards2 = np.array([0, 0, 0.1, 0.4, 0.5, 0.5, 0.55, 0.8, 0.9, 1]) + rewards = np.concatenate((rewards1, rewards2)) # concatenation array + episode1 = range(len(rewards1)) + episode2 = range(len(rewards2)) + episode = np.concatenate((episode1, episode2)) + data1 = {} + data1['x'] = episode + data1['y'] = rewards + data1['label'] = 'line1' + + rewards3 = np.random.random(10) + rewards4 = np.random.random(10) + rewards = np.concatenate((rewards3, rewards4)) # concatenation array + episode3 = range(len(rewards1)) + episode4 = range(len(rewards2)) + episode = np.concatenate((episode3, episode4)) + data2 = {} + data2['x'] = episode + data2['y'] = rewards + data2['label'] = 'line2' + + data = [data1, data2] + plot(data, 'step', 'reward_rate', 'test_pic', './pic.jpg') + assert os.path.exists('./pic.jpg') diff --git a/setup.py b/setup.py index a7ebb10b6ff34a9c62cdd6192e919f25e56c6f58..20dbe024db40b9749c367639a5b195afa134689b 100755 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ setup( 'easydict==1.9', 'tensorboardX>=2.1,<=2.2', 'matplotlib', # pypy incompatible + 'seaborn', 'yapf==0.29.0', 'responses~=0.12.1', 'flask~=1.1.2',