Aggiungere una legenda a PyPlot in Matplotlib nel modo più semplice possibile

TL;DR -> Come si può creare una legenda per un grafico a linee in Matplotlib's PyPlot senza creare alcuna variabile extra?

Considerate lo script del grafico qui sotto:

if __name__ == '__main__':
    PyPlot.plot(total_lengths, sort_times_bubble, 'b-',
                total_lengths, sort_times_ins, 'r-',
                total_lengths, sort_times_merge_r, 'g+',
                total_lengths, sort_times_merge_i, 'p-', )
    PyPlot.title("Combined Statistics")
    PyPlot.xlabel("Length of list (number)")
    PyPlot.ylabel("Time taken (seconds)")
    PyPlot.show()

Come potete vedere, questo è un uso molto semplice di matplotlib's PyPlot. Questo genera idealmente un grafico come quello qui sotto:

Niente di speciale, lo so. Tuttavia, non è chiaro quali dati vengono tracciati dove (sto cercando di tracciare i dati di alcuni algoritmi di ordinamento, la lunghezza contro il tempo impiegato, e vorrei assicurarmi che la gente sappia quale linea è quale). Quindi, ho bisogno di una legenda, tuttavia, dando un'occhiata al seguente esempio qui sotto (dal sito ufficiale):

ax = subplot(1,1,1)
p1, = ax.plot([1,2,3], label="line 1")
p2, = ax.plot([3,2,1], label="line 2")
p3, = ax.plot([2,3,1], label="line 3")

handles, labels = ax.get_legend_handles_labels()

# reverse the order
ax.legend(handles[::-1], labels[::-1])

# or sort them by labels
import operator
hl = sorted(zip(handles, labels),
            key=operator.itemgetter(1))
handles2, labels2 = zip(*hl)

ax.legend(handles2, labels2)

Vedrete che ho bisogno di creare una variabile extra ax. Come posso aggiungere una legenda al mio grafico senza dover creare questa variabile extra e mantenendo la semplicità del mio script attuale.

Soluzione

Aggiungete una label= a ciascuna delle vostre chiamate plot(), e poi chiamate legend(loc='upper left').

Considerate questo esempio (testato con Python 3.8.0):

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 20, 1000)
y1 = np.sin(x)
y2 = np.cos(x)

plt.plot(x, y1, "-b", label="sine")
plt.plot(x, y2, "-r", label="cosine")
plt.legend(loc="upper left")
plt.ylim(-1.5, 2.0)
plt.show()

Leggermente modificato da questo tutorial: http://jakevdp.github.io/mpl_tutorial/tutorial_pages/tut1.html

Commentari (5)

Ecco un esempio per aiutarvi...

fig = plt.figure(figsize=(10,5))
ax = fig.add_subplot(111)
ax.set_title('ADR vs Rating (CS:GO)')
ax.scatter(x=data[:,0],y=data[:,1],label='Data')
plt.plot(data[:,0], m*data[:,0] + b,color='red',label='Our Fitting 
Line')
ax.set_xlabel('ADR')
ax.set_ylabel('Rating')
ax.legend(loc='best')
plt.show()

Commentari (0)

Aggiungete delle etichette ad ogni argomento nella vostra chiamata plot corrispondente alla serie che sta graficando, cioè label = "serie 1"

Poi aggiungete semplicemente Pyplot.legend() alla fine del vostro script e la leggenda mostrerà queste etichette.

Commentari (1)