티스토리 뷰

Study/AI

[Python][Visualization] matplotlib multiple plot

생각많은 소심남 2018. 7. 3. 00:10

 보통 Data Science상에서 주로 쓰는 Python Visualization package가 여러 개 있는데, 많이 쓰는 것이 Scipy package에 들어있는 matplotlib.pyplot (보통 plt라고 alias해서 사용한다.) 거기서 주로 사용하는 기능에 대해서 간단히 다뤄보고자 한다.

 우선 첫번째로 다룰 기능 multiple plot이다. 보통 그래프를 그리게 되면 한 plot당 하나의 그래프가 출력되겠지만, 필요에 따라서는 여러 개의 plot을 하나의 grid에 출력하고 싶을 때도 있을 것이고, 혹은 하나의 plt당 여러개의 plot을 나눠서 따로따로 출력하고 싶을 수도 있다. 일단 가장 흔하게 할 수 있는 건 하나의 common 축을 가지는 그래프를 출력해보는 것은 다음과 같다.
(참고 : 예제에 쓰인 데이터는 2000년대 나스닥에 상장된 4대 IT기업의 나스닥지수이다.)

stocks.csv

  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3.  
  4. file = '../stocks.csv'
  5.  
  6. # Convert RangeIndex to DateTimeIndex
  7. stock = pd.read_csv(file, parse_dates=['Date'], index_col=['Date'])
  8.  
  9. plt.plot(stock.index, stock['AAPL'], color='red')
  10. plt.plot(stock.index, stock['MSFT'], color='blue')
  11.  
  12. plt.title('Stock price in 2000')
  13. plt.xlabel('Date')
  14. plt.ylabel('Stock price')
  15.  
  16. plt.show()

 이러면 아래와 같이 하나의 plot안에 공통의 축을 가지는 plot을 두개 그릴 수 있다.

두번째는 subplot을 활용한 plot이다. 이를 활용하면 위와 같이 한 grid에 같이 표현되는게 아니라 두개의 plot이 따로따로 출력된다. code로 표현하면 다음과 같이 된다.

  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3.  
  4. file = '../stocks.csv'
  5.  
  6. # Convert RangeIndex to DateTimeIndex
  7. stock = pd.read_csv(file, parse_dates=['Date'], index_col=['Date'])
  8.  
  9. plt.subplot(2,1,1)
  10. plt.plot(stock.index, stock['AAPL'], color='red')
  11. plt.xlabel('Date')
  12. plt.ylabel('Stock price')
  13. plt.title('AAPL Stock price in 2000')
  14.  
  15. plt.subplot(2,1,2)
  16. plt.plot(stock.index, stock['MSFT'], color='blue')
  17. plt.title('MSFT Stock price in 2000')
  18. plt.xlabel('Date')
  19. plt.ylabel('Stock price')
  20.  
  21. plt.show()

우선 subplot을 사용했는데, 인자로 들어가는 값이 다음과 같다.

plt.subplot(nrows, ncols, index)
- nrows : plot의 row 갯수
- ncols : plot의 column 갯수
- index : plot의 index

그래서 위와 같이 표현하면 row가 2, column이 1인 plot을 두개 그리되, 첫번째 plot을 위에, 두번째 plot을 아래에 배치하겠다는 것이다. 이렇게 하면 딱 두개의 그래프가 출력된다. 

그런데 그래프가 보면 위 그래프의 xlabel과 아래그래프의 title이 겹쳐지는 것을 볼수가 있다. 물론 그래프의 배치를 위와 같이 겹치지 않게 row는 1, column 2인 subplot을 그려도 좋겠지만, 이 상태에서도 글자가 겹치지 않게 하는 방법이 있다. 

 

  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3.  
  4. file = '../stocks.csv'
  5.  
  6. # Convert RangeIndex to DateTimeIndex
  7. stock = pd.read_csv(file, parse_dates=['Date'], index_col=['Date'])
  8.  
  9. plt.subplot(2,1,1)
  10. plt.plot(stock.index, stock['AAPL'], color='red')
  11. plt.xlabel('Date')
  12. plt.ylabel('Stock price')
  13. plt.title('AAPL Stock price in 2000')
  14.  
  15. plt.subplot(2,1,2)
  16. plt.plot(stock.index, stock['MSFT'], color='blue')
  17. plt.title('MSFT Stock price in 2000')
  18. plt.xlabel('Date')
  19. plt.ylabel('Stock price')
  20.  
  21. plt.tight_layout()
  22. plt.show()

 차이점을 발견했겠지만 21번째 line에 있는 tight_layout을 사용하게 되면 plt를 그리는 각 객체가 overlay되지 않고 이쁘게 그려진다. 

 유의할 점이라면 subplot의 plot이 되는 기준은 top-left부터 row indexing되는 방향으로 그려지며, 일반적으로 python의 index가 0부터 시작하는데 비해서 subplot index는 1부터 시작한다는 점이다.

댓글