diff --git a/benchmark/paddle/image/plotlog.py b/benchmark/paddle/image/plotlog.py index ce9d6ac24ada46470c941fd88164f27eea7f483d..34043c49948a69b18e4b8e69997dc61b30dc0bd4 100644 --- a/benchmark/paddle/image/plotlog.py +++ b/benchmark/paddle/image/plotlog.py @@ -70,12 +70,19 @@ def sample(metric, sample_rate): return metric_sample -def plot_metric(metric, batch_id, graph_title): +def plot_metric(metric, batch_id, graph_title, line_style='b-', + line_label='y', + line_num=1): plt.figure() plt.title(graph_title) - plt.plot(batch_id, metric) + if line_num == 1: + plt.plot(batch_id, metric, line_style, line_label) + else: + for i in line_num: + plt.plot(batch_id, metric[i], line_style[i], line_label[i]) plt.xlabel('batch') plt.ylabel(graph_title) + plt.legend() plt.savefig(graph_title + '.jpg') plt.close() @@ -91,8 +98,8 @@ def main(): loss_sample = sample(loss, args.sample_rate) accuracy_sample = sample(accuracy, args.sample_rate) - plot_metric(loss_sample, batch_sample, 'loss') - plot_metric(accuracy_sample, batch_sample, 'accuracy') + plot_metric(loss_sample, batch_sample, 'loss', line_label='loss') + plot_metric(accuracy_sample, batch_sample, 'accuracy', line_style='g-', line_label='accuracy') if __name__ == '__main__':