seabornを利用して簡単にヒートマップを作成する方法を紹介します。
ついでに、全columnに対する相関係数の作成方法も扱います。
データセットとしては、Bike Sharing Data Setを利用します。
UCI Bike Sharing Data set
ポイント
今回のポイントは2つです。
一つ目は、Dataframeから相関係数のDataframeを作成する方法で、非常に単純です。
相互に相関係数を取りたいdataframeを作成し、dataframeのcorr()メソッドを使うだけです。
# df_for_corr: 相関係数を導出する対象のdataframe
corr_df = df_for_corr.corr()
これでcorr_dfという相関係数のdataframeが作成されます。
中身はコードの詳細のところで示しますが、正方行列になっています。
二つ目は、ヒートマップの作成部分です。こちらも描画自体はシンプルに1行で書くことができます。
# すでにimport seaborn as snsでseabornが使える状態であるとする。より使いやすい版を後述。
fig, ax = plt.subplots(figsize=(8,8)) #square=Trueを入れると各グリッドが正方形になる
sns.heatmap(data=corr_df, cmap="RdBu_r", annot=True, fmt=".2f", vmax=1, vmin=-1, square=True)
描画結果がこちら。
annot=Trueにして値を表示させていますが、annot=Falseにした方が汎用性は高いです。
また、引数にlinewidthとlinecolorを入れると見やすくなる場合があります。
# 実用版1
fig, ax = plt.subplots(figsize=(8,8))
sns.heatmap(data=corr_df, cmap="RdBu_r", annot=False, vmax=1, vmin=-1, linecolor="white", linewidths=.5, square=True)
その他、y軸の順序を反転させたい場合や、横のカラーバーを表示したくない場合は下記の例が参考になるかと思います。
# 実用版2。カラーバーを非表示にした分、画像自体が大きくなることに注意。
fig, ax = plt.subplots(figsize=(8,8))
sns.heatmap(data=corr_df, cmap="RdBu_r", vmax=1, vmin=-1, square=True, linecolor="white", linewidth=1, cbar=False)
このコードは、ヒートマップの描画対象が相関係数だと想定しての設定です。
引数の設定が不要であれば、最悪dataだけ準備すればOKです。
色なども含めた細かな設定は、下記のドキュメントや詳しく記載してくださっているブログ等をご参照ください。
seaborn.heatmap
seaborn Choosing color palettes
Pythonでデータサイエンス Seaborn でヒートマップを作成する
また、seabornを初めて使う方は、snsという略称を疑問に思うかもしれませんが、seabornをimportする時に慣例的に使用されている名称なので、あまり気にしないでも大丈夫です。
コード全体
一番初めのグラフを例に、コードの全体像を確認しましょう。
なお、下記のコードはJupyter Notebook上で使用することを想定しています。
まずはライブラリとデータをimportします。
一部使用していないライブラリも含みますが、ご了承ください。
まずは相関行列を作成するところまでです。
# ライブラリのimport
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import os
import sys
import seaborn as sns
from IPython.display import display
# データのimport。ファイルパスは必要に応じて変更。
bike_data = pd.read_csv("../input/Bike-Sharing-Dataset/day.csv")
# 今回はすべてのデータに対する相関係数ではなく、列を限定。
col_for_corr = ["temp", "atemp", "hum", "windspeed", "casual", "registered", "cnt"]
df_for_corr = bike_data[col_for_corr]
# 相関係数を取るdataframeが準備できたので、corr()で相関行列を作成。
corr_df = df_for_corr.corr()
corr_dfの中身を確認するとこのようになります。
corr_df
temp | atemp | hum | windspeed | casual | registered | cnt | |
---|---|---|---|---|---|---|---|
temp | 1.000000 | 0.991702 | 0.126963 | -0.157944 | 0.543285 | 0.540012 | 0.627494 |
atemp | 0.991702 | 1.000000 | 0.139988 | -0.183643 | 0.543864 | 0.544192 | 0.631066 |
hum | 0.126963 | 0.139988 | 1.000000 | -0.248489 | -0.077008 | -0.091089 | -0.100659 |
windspeed | -0.157944 | -0.183643 | -0.248489 | 1.000000 | -0.167613 | -0.217449 | -0.234545 |
casual | 0.543285 | 0.543864 | -0.077008 | -0.167613 | 1.000000 | 0.395282 | 0.672804 |
registered | 0.540012 | 0.544192 | -0.091089 | -0.217449 | 0.395282 | 1.000000 | 0.945517 |
cnt | 0.627494 | 0.631066 | -0.100659 | -0.234545 | 0.672804 | 0.945517 | 1.000000 |
それでは、ヒートマップを作成します。
fig, ax = plt.subplots(figsize=(8,8))
# ヒートマップの描画
sns.heatmap(data=corr_df, cmap="RdBu_r", annot=True, fmt=".2f", vmax=1, vmin=-1, square=True)
# タイトル、ラベルの設定。
ax.set_title("correlation")
ax.set_xticklabels(ax.get_xmajorticklabels(), rotation=0)
ax.set_yticklabels(ax.get_ymajorticklabels(), rotation=0)
plt.savefig("../output/heatmap.png", bbox_inches="tight", dpi=120)
# 画像として保存
plt.savefig("../output/heatmap.png", bbox_inches="tight", dpi=120)
ラベルに関する記述をつけたのは、長いラベルだと入りきらない場合があるためです。
色などの設定を変えたい場合は、ドキュメント等をご参照ください。
まとめ
Pythonでの簡単なヒートマップの作成方法でした。
基本的な使い方がわかればいくらでも応用はできると思いますので、適宜調べながら使ってみてください。
参考文献
Bike Sharing Dataset
Fanaee-T, Hadi, and Gama, Joao, ‘Event labeling combining ensemble detectors and background knowledge’, Progress in Artificial Intelligence (2013): pp. 1-15, Springer Berlin Heidelberg,
seaborn.heatmap
seaborn Choosing color palettes
Pythonでデータサイエンス