未验证 提交 48d6c826 编写于 作者: W Weiyuhong-1998 提交者: GitHub

fix(wyh): add plot function (#59)

* fix(wyh): plot function

* fix(wyh): plot function pytest

* fix(wyh):plot function modify comments

* feature(wyh):plot style
Co-authored-by: Nweiyuhong <weiyuhong@sensetime.com>
上级 07ceba40
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
def plot(data: list, xlabel: str, ylabel: str, title: str, pth: str = './picture.jpg'):
"""
Overview:
Draw training polyline
Interface:
data (:obj:`List[Dict]`): the data we will use to draw polylines
data[i]['step']: horizontal axis data
data[i]['value']: vertical axis data
data[i]['label']: the data label
xlabel (:obj:`str`): the x label name
ylabel (:obj:`str`): the y label name
title (:obj:`str`): the title name
"""
sns.set(style="darkgrid", font_scale=1.5)
for nowdata in data:
step, value, label = nowdata['x'], nowdata['y'], nowdata['label']
sns.lineplot(x=step, y=value, label=label)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)
plt.savefig(pth)
import random
import numpy as np
import os
import pytest
from ding.utils.plot_helper import plot
@pytest.mark.unittest
def test_plot():
rewards1 = np.array([0, 0.1, 0, 0.2, 0.4, 0.5, 0.6, 0.9, 0.9, 0.9])
rewards2 = np.array([0, 0, 0.1, 0.4, 0.5, 0.5, 0.55, 0.8, 0.9, 1])
rewards = np.concatenate((rewards1, rewards2)) # concatenation array
episode1 = range(len(rewards1))
episode2 = range(len(rewards2))
episode = np.concatenate((episode1, episode2))
data1 = {}
data1['x'] = episode
data1['y'] = rewards
data1['label'] = 'line1'
rewards3 = np.random.random(10)
rewards4 = np.random.random(10)
rewards = np.concatenate((rewards3, rewards4)) # concatenation array
episode3 = range(len(rewards1))
episode4 = range(len(rewards2))
episode = np.concatenate((episode3, episode4))
data2 = {}
data2['x'] = episode
data2['y'] = rewards
data2['label'] = 'line2'
data = [data1, data2]
plot(data, 'step', 'reward_rate', 'test_pic', './pic.jpg')
assert os.path.exists('./pic.jpg')
......@@ -53,6 +53,7 @@ setup(
'easydict==1.9',
'tensorboardX>=2.1,<=2.2',
'matplotlib', # pypy incompatible
'seaborn',
'yapf==0.29.0',
'responses~=0.12.1',
'flask~=1.1.2',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册