3次元のScatter Plot

主成分解析やt-SNEなど、3次元まで圧縮したデータをプロットするときには、マウスなどで「回せる」のが良い。以下、その方法を解析する。まずは例としてMNISTのPCAのデータを以下のように用意する。

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn

def load_csv(csv):
    xx=np.array(pd.read_csv(csv))
    x_data=xx[:,1:].astype('float32') / 255
    y_data=xx[:,0]
    
    return x_data, y_data

x_train, y_train=load_csv('./mnist_train.csv')
x_test, y_test=load_csv('./mnist_test.csv')

n=int(x_train.shape[0]/10) #データの個数(x_train.shape[0])を10で割って整数にしている
x_train2=x_train[:n,]
y_train2=y_train[:n]

n=int(x_test.shape[0]/10) 
x_test2=x_test[:n,]
y_test2=y_test[:n]

from sklearn.decomposition import PCA

pca=PCA(n_components=8)
transformed=pca.fit_transform(x_train2)
In [2]:
cmap='tab10'
plt.scatter(transformed[:,0],transformed[:,1],s=3,c=y_train2,cmap=cmap)
plt.colorbar()
plt.xlabel("PC1")
plt.ylabel("PC2")
Out[2]:
Text(0, 0.5, 'PC2')

こんな感じでtransformedという配列に、主成分スコアが入っているとする。

さて、上のデータを3次元でプロットするために、cufflinksというライブラリを用いる。pipでインストールしているはずなので(していない人は、「pip install cufflinks」を実行)、下のようにインポートすれば良いはずである。

In [3]:
import cufflinks as cf
cf.go_offline() #オフラインで使うという指示

cufflinksを用いてプロットするためには、入力をデータフレームという形式に変換する必要がある。データフレームとは、pandasのデータ構造の一種で、2次元のデータを扱い易くした形式であるが、今はあまり詳細に立ち入らない。

以下のようにして、PCAのスコアが入っているtransformedという配列をデータフレームに変換する。

In [4]:
df=pd.DataFrame(transformed[0:6000,0:3],columns=list("XYZ"))

ここで、単に「df」とセルに入力して実行すると、以下のような出力が見られるはずである。

In [5]:
df
Out[5]:
X Y Z
0 0.532408 1.373723 0.136337
1 4.249131 1.259853 1.993629
2 -0.284628 -1.745703 -1.154596
3 -3.434115 2.445623 0.644639
4 -1.495364 -2.667722 0.040201
... ... ... ...
5995 1.867286 -3.858886 0.147401
5996 -0.231575 -0.635003 -0.284178
5997 0.039256 2.652331 2.936591
5998 1.149301 0.204864 1.305024
5999 0.817908 -3.176972 1.883256

6000 rows × 3 columns

カラムの名前をX, Y, Zと付けている。次に、ラベル情報をこのデータフレームに加える。後から使うプロットの関数の都合で、このラベル情報は文字列である必要がある。なので、以下のようにしてデータフレームdfに「label」のカラムを加える。

In [6]:
df["label"]=np.array(y_train2,dtype=str)

そうすると、「label」というカラムが加わっているはずです。

In [7]:
df
Out[7]:
X Y Z label
0 0.532408 1.373723 0.136337 5
1 4.249131 1.259853 1.993629 0
2 -0.284628 -1.745703 -1.154596 4
3 -3.434115 2.445623 0.644639 1
4 -1.495364 -2.667722 0.040201 9
... ... ... ... ...
5995 1.867286 -3.858886 0.147401 7
5996 -0.231575 -0.635003 -0.284178 6
5997 0.039256 2.652331 2.936591 8
5998 1.149301 0.204864 1.305024 6
5999 0.817908 -3.176972 1.883256 9

6000 rows × 4 columns

さて、このデータフレームを用いて、3次元のプロットを作成する。以下のコマンドで3次元のプロットが出力されるはずである。入力のcategoriesでラベル情報を入れている。sizeは点の大きさをしていている。マウスによる操作で回転、拡大/縮小などが出来るはずである。

In [8]:
df.iplot(x="X",y="Y",z="Z",kind="scatter3d", categories="label", size=4, xTitle="PC1", yTitle="PC2", zTitle="PC3")

色の順番がおかしいのは、データファイルの中で出てくる順番で色が割り振られるからである。気になるのであればデータをラベルでソートすれば良い。6000個のデータ点があるが、少しデータ点が多すぎると感じたら、以下のようにしてデータ数を(例えば1000まで)減らしてプロットしてみると良い。

In [9]:
df2=df[0:1000]
In [10]:
df2.iplot(x="X",y="Y",z="Z",kind="scatter3d", categories="label", size=5, xTitle="PC1", yTitle="PC2", zTitle="PC3")
In [ ]: