VAE Using CNNs for the Transformations

First, import the libraries and load the data. This part is identical to lecture.html.

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import sklearn
import time

cmap='tab10' # color map used when plotting

def load_csv(csv):
    xx=np.array(pd.read_csv(csv))
    x_data=xx[:,1:].astype('float32') / 255 # pixel values, scaled to [0, 1]
    y_data=xx[:,0]                          # labels are in the first column

    return x_data, y_data

x_train, y_train=load_csv('./mnist_train.csv')
x_test, y_test=load_csv('./mnist_test.csv')

n=int(x_train.shape[0]/10) # number of samples (x_train.shape[0]) divided by 10 and cast to int
x_train2=x_train[:n,]
y_train2=y_train[:n]

n=int(x_test.shape[0]/10) 
x_test2=x_test[:n,]
y_test2=y_test[:n]

import keras

from keras.layers import Lambda, Input, Dense, Dropout
from keras.models import Model
from keras.losses import mse, binary_crossentropy
from keras.models import Sequential
from keras.layers import MaxPooling2D

from keras import layers
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose

from keras import backend as K

pixel_size=28
Using TensorFlow backend.

The function vae_cnn below runs a VAE that uses CNNs for the transformations. It is used in the same way as the function vae_mlp.

In [3]:
def vae_cnn(x_train, x_test, latent_dim = 2, epochs = 10, pixel_size=pixel_size):
    x_train = x_train.reshape((x_train.shape[0],pixel_size,pixel_size,1))
    x_test = x_test.reshape((x_test.shape[0],pixel_size,pixel_size,1))
    # network parameters
    input_shape = (pixel_size, pixel_size, 1)
    batch_size = 128
    kernel_size = 3
    filters = 16
    # VAE model = encoder + decoder
    # build encoder model
    inputs = Input(shape=input_shape, name='encoder_input')
    
    x = inputs
    x = Conv2D(filters=32,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)
    x = Conv2D(filters=64,
               kernel_size=kernel_size,
               activation='relu',
               strides=2,
               padding='same')(x)
    # shape info needed to build decoder model
    shape = K.int_shape(x)
    # generate latent vector Q(z|X)
    x = Flatten()(x)
    x = Dense(16, activation='relu')(x)
    z_mean = Dense(latent_dim, name='z_mean')(x)
    z_log_var = Dense(latent_dim, name='z_log_var')(x)
    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
    # instantiate encoder model
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
    encoder.summary()
    
    #plot_model(encoder, to_file='vae_cnn_encoder.png', show_shapes=True)
    # build decoder model
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    
    x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
    x = Reshape((shape[1], shape[2], shape[3]))(x)
    x = Conv2DTranspose(filters=64,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    x = Conv2DTranspose(filters=32,
                        kernel_size=kernel_size,
                        activation='relu',
                        strides=2,
                        padding='same')(x)
    
    outputs = Conv2DTranspose(filters=1,
                              kernel_size=kernel_size,
                              activation='sigmoid',
                              padding='same',
                              name='decoder_output')(x)
    # instantiate decoder model
    decoder = Model(latent_inputs, outputs, name='decoder')
    decoder.summary()
    #plot_model(decoder, to_file='vae_cnn_decoder.png', show_shapes=True)
    # instantiate VAE model
    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs, outputs, name='vae')
    #models = (encoder, decoder)
    #data = (x_test, y_test)
    # VAE loss = mse_loss or xent_loss + kl_loss
    reconstruction_loss = binary_crossentropy(K.flatten(inputs),
                                            K.flatten(outputs))
    reconstruction_loss *= pixel_size * pixel_size
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='rmsprop')
    vae.summary()
    # train the autoencoder
    result = vae.fit(x_train,
                     epochs=epochs,
                     batch_size=batch_size,
                     validation_data=(x_test, None))
    return encoder, decoder, result

def sampling(args):
    # Reparameterization trick: draw z = mean + std * eps with eps ~ N(0, I),
    # so the sampling step stays differentiable w.r.t. z_mean and z_log_var.
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
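
Two details in the cell above are worth noting. The KL term is the closed-form KL divergence between the approximate Gaussian posterior and a standard normal, kl_loss = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2), which is exactly what the three kl_loss lines compute. And the sampling function implements the reparameterization trick, which can be sanity-checked with plain NumPy; a minimal sketch (the values 1.5 and 0.25 are arbitrary choices for illustration):

In [ ]:
# Numerical check of the reparameterization trick (NumPy only, no Keras needed):
# z = mu + exp(0.5 * log_var) * eps with eps ~ N(0, 1) should have
# sample mean close to mu and sample variance close to exp(log_var).
mu, log_var = 1.5, np.log(0.25)  # target mean 1.5, variance 0.25
eps = np.random.normal(size=100000)
z = mu + np.exp(0.5 * log_var) * eps
print(z.mean(), z.var())         # expect roughly 1.5 and 0.25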

The function is run as follows. The cell below feeds in x_train2, the reduced dataset; to use the full dataset, simply pass the corresponding arrays instead.

In [4]:
encoder, decoder, result= vae_cnn(x_train2, x_test2, latent_dim = 2, epochs = 30, pixel_size=pixel_size)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 14, 14, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 7, 7, 64)     18496       conv2d_1[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 3136)         0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 16)           50192       flatten_1[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense_1[0][0]                    
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense_1[0][0]                    
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
z_sampling (InputLayer)      (None, 2)                 0         
_________________________________________________________________
dense_2 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_1 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1)         289       
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
C:\Users\chkar\Anaconda3\envs\VAE2019_2\lib\site-packages\keras\engine\training_utils.py:819: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.
  'be expecting any data to be passed to {0}.'.format(name))
Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 69076     
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         65089     
=================================================================
Total params: 134,165
Trainable params: 134,165
Non-trainable params: 0
_________________________________________________________________
Train on 6000 samples, validate on 1000 samples
Epoch 1/30
6000/6000 [==============================] - 5s 820us/step - loss: 292.2724 - val_loss: 209.9396
Epoch 2/30
6000/6000 [==============================] - 4s 744us/step - loss: 210.9115 - val_loss: 201.6409
Epoch 3/30
6000/6000 [==============================] - 5s 751us/step - loss: 201.9384 - val_loss: 192.8190
Epoch 4/30
6000/6000 [==============================] - 5s 815us/step - loss: 197.2687 - val_loss: 189.4416
Epoch 5/30
6000/6000 [==============================] - 4s 705us/step - loss: 190.0335 - val_loss: 181.5865
Epoch 6/30
6000/6000 [==============================] - 4s 724us/step - loss: 181.3990 - val_loss: 176.7970
Epoch 7/30
6000/6000 [==============================] - 5s 833us/step - loss: 178.9067 - val_loss: 173.3379
Epoch 8/30
6000/6000 [==============================] - 5s 758us/step - loss: 176.8555 - val_loss: 173.0774
Epoch 9/30
6000/6000 [==============================] - 5s 780us/step - loss: 175.3090 - val_loss: 170.9065
Epoch 10/30
6000/6000 [==============================] - 5s 829us/step - loss: 174.3787 - val_loss: 170.3192
Epoch 11/30
6000/6000 [==============================] - 5s 798us/step - loss: 173.0755 - val_loss: 175.3299
Epoch 12/30
6000/6000 [==============================] - 6s 918us/step - loss: 172.2265 - val_loss: 170.5376
Epoch 13/30
6000/6000 [==============================] - 5s 805us/step - loss: 171.3049 - val_loss: 169.6315
Epoch 14/30
6000/6000 [==============================] - 5s 778us/step - loss: 170.5933 - val_loss: 167.6497
Epoch 15/30
6000/6000 [==============================] - 5s 813us/step - loss: 170.0217 - val_loss: 167.1276
Epoch 16/30
6000/6000 [==============================] - 5s 865us/step - loss: 169.0268 - val_loss: 168.3431
Epoch 17/30
6000/6000 [==============================] - 5s 817us/step - loss: 168.2137 - val_loss: 166.5651
Epoch 18/30
6000/6000 [==============================] - 5s 763us/step - loss: 168.0051 - val_loss: 166.7262
Epoch 19/30
6000/6000 [==============================] - 5s 754us/step - loss: 167.1926 - val_loss: 169.1962
Epoch 20/30
6000/6000 [==============================] - 5s 768us/step - loss: 166.7274 - val_loss: 172.7280
Epoch 21/30
6000/6000 [==============================] - 5s 754us/step - loss: 166.4733 - val_loss: 164.1067
Epoch 22/30
6000/6000 [==============================] - 5s 797us/step - loss: 166.0091 - val_loss: 164.6365
Epoch 23/30
6000/6000 [==============================] - 5s 911us/step - loss: 165.3659 - val_loss: 164.2008
Epoch 24/30
6000/6000 [==============================] - 6s 924us/step - loss: 165.2271 - val_loss: 168.0800
Epoch 25/30
6000/6000 [==============================] - 5s 781us/step - loss: 164.5326 - val_loss: 164.6051
Epoch 26/30
6000/6000 [==============================] - 5s 760us/step - loss: 164.1372 - val_loss: 169.6967
Epoch 27/30
6000/6000 [==============================] - 5s 779us/step - loss: 164.0102 - val_loss: 163.7089
Epoch 28/30
6000/6000 [==============================] - 5s 783us/step - loss: 163.9682 - val_loss: 163.9582
Epoch 29/30
6000/6000 [==============================] - 5s 777us/step - loss: 163.2502 - val_loss: 164.9459
Epoch 30/30
6000/6000 [==============================] - 5s 794us/step - loss: 163.1323 - val_loss: 163.8535

As in lecture.html, the loss can be plotted as follows.

In [5]:
plt.plot(result.history['loss'],label='loss',color='r')
plt.plot(result.history['val_loss'],label='val_loss',color='b')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

Next, we draw the latent space. To obtain z_mean, the input has to be reshaped to the 28x28 format that the CNN expects, so the input to encoder.predict is written as follows.

In [6]:
z_mean, _, _ = encoder.predict(x_train2.reshape((x_train2.shape[0],pixel_size,pixel_size,1)))

plt.figure(figsize=(6, 5))
plt.scatter(z_mean[:, 0], z_mean[:, 1], s=3, c=y_train2, cmap=cmap)
plt.colorbar()
Out[6]:
<matplotlib.colorbar.Colorbar at 0x1bdbb606a88>
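
As a quick sanity check on the trained model, one can also encode a few test digits and decode them again, comparing originals with reconstructions. A minimal sketch, using the encoder and decoder returned above (the choice of 8 digits is arbitrary):

In [ ]:
# Sketch: originals (top row) vs. reconstructions (bottom row).
x_in = x_test2[:8].reshape((-1, pixel_size, pixel_size, 1))
z_mu, _, _ = encoder.predict(x_in)  # use the latent means as codes
x_rec = decoder.predict(z_mu)

plt.figure(figsize=(8, 2))
for i in range(8):
    plt.subplot(2, 8, i + 1)
    plt.imshow(x_in[i].reshape(pixel_size, pixel_size), cmap='Greys_r')
    plt.axis('off')
    plt.subplot(2, 8, 8 + i + 1)
    plt.imshow(x_rec[i].reshape(pixel_size, pixel_size), cmap='Greys_r')
    plt.axis('off')
plt.show()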

Furthermore, to draw what kinds of shapes are generated as the latent variable z is varied, the following function can be used.

In [7]:
def plot_latent(decoder):
    # Decode an n x n grid of points in the 2D latent space and tile
    # the decoded digits into one large image.
    n=30
    figure = np.zeros((pixel_size * n, pixel_size * n))

    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(pixel_size, pixel_size)
            figure[i * pixel_size: (i + 1) * pixel_size,
                   j * pixel_size: (j + 1) * pixel_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = pixel_size // 2
    end_range = (n - 1) * pixel_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, pixel_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.show()
In [8]:
plot_latent(decoder)

The difference between the VAE without convolutions and the one with them may not come out clearly on the MNIST data reduced to one tenth. Using the full dataset, compare how val_loss decreases and what the latent-space reconstructions look like, though it may take some time.
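
One possible setup for that comparison is sketched below. It assumes the function vae_mlp from lecture.html is defined in the same session and, as noted above, takes the same arguments as vae_cnn apart from pixel_size; adjust the call if its actual signature differs:

In [ ]:
# Hypothetical comparison sketch; assumes vae_mlp from lecture.html is in scope.
_, _, result_mlp = vae_mlp(x_train, x_test, latent_dim=2, epochs=30)
_, _, result_cnn = vae_cnn(x_train, x_test, latent_dim=2, epochs=30, pixel_size=pixel_size)

plt.plot(result_mlp.history['val_loss'], label='val_loss (MLP)', color='b')
plt.plot(result_cnn.history['val_loss'], label='val_loss (CNN)', color='r')
plt.legend()
plt.xlabel('epoch')
plt.ylabel('val_loss')
plt.show()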

The results using the full dataset are shown below.

In [9]:
t1 = time.time() 
encoder, decoder, result= vae_cnn(x_train, x_test, latent_dim = 2, epochs = 30, pixel_size=pixel_size)
t2 = time.time()
print(f"Elapsed time: {t2-t1}")
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
encoder_input (InputLayer)      (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 7, 7, 64)     18496       conv2d_3[0][0]                   
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 3136)         0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 16)           50192       flatten_2[0][0]                  
__________________________________________________________________________________________________
z_mean (Dense)                  (None, 2)            34          dense_3[0][0]                    
__________________________________________________________________________________________________
z_log_var (Dense)               (None, 2)            34          dense_3[0][0]                    
__________________________________________________________________________________________________
z (Lambda)                      (None, 2)            0           z_mean[0][0]                     
                                                                 z_log_var[0][0]                  
==================================================================================================
Total params: 69,076
Trainable params: 69,076
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "decoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
z_sampling (InputLayer)      (None, 2)                 0         
_________________________________________________________________
dense_4 (Dense)              (None, 3136)              9408      
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 14, 14, 64)        36928     
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 28, 28, 32)        18464     
_________________________________________________________________
decoder_output (Conv2DTransp (None, 28, 28, 1)         289       
=================================================================
Total params: 65,089
Trainable params: 65,089
Non-trainable params: 0
_________________________________________________________________
Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_input (InputLayer)   (None, 28, 28, 1)         0         
_________________________________________________________________
encoder (Model)              [(None, 2), (None, 2), (N 69076     
_________________________________________________________________
decoder (Model)              (None, 28, 28, 1)         65089     
=================================================================
Total params: 134,165
Trainable params: 134,165
Non-trainable params: 0
_________________________________________________________________
C:\Users\chkar\Anaconda3\envs\VAE2019_2\lib\site-packages\keras\engine\training_utils.py:819: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.
  'be expecting any data to be passed to {0}.'.format(name))
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 47s 784us/step - loss: 193.6665 - val_loss: 172.0591
Epoch 2/30
60000/60000 [==============================] - 47s 789us/step - loss: 167.4011 - val_loss: 164.4291
Epoch 3/30
60000/60000 [==============================] - 47s 783us/step - loss: 162.2286 - val_loss: 158.9856
Epoch 4/30
60000/60000 [==============================] - 53s 886us/step - loss: 159.4374 - val_loss: 158.0446
Epoch 5/30
60000/60000 [==============================] - 55s 921us/step - loss: 157.6928 - val_loss: 157.4837
Epoch 6/30
60000/60000 [==============================] - 53s 877us/step - loss: 156.2346 - val_loss: 154.9563
Epoch 7/30
60000/60000 [==============================] - 55s 914us/step - loss: 155.0817 - val_loss: 153.7993
Epoch 8/30
60000/60000 [==============================] - 54s 907us/step - loss: 153.9460 - val_loss: 154.8155
Epoch 9/30
60000/60000 [==============================] - 53s 882us/step - loss: 153.0785 - val_loss: 151.6348
Epoch 10/30
60000/60000 [==============================] - 65s 1ms/step - loss: 152.1906 - val_loss: 152.2074
Epoch 11/30
60000/60000 [==============================] - 63s 1ms/step - loss: 151.5213 - val_loss: 151.1470
Epoch 12/30
60000/60000 [==============================] - 58s 967us/step - loss: 150.9044 - val_loss: 150.5482
Epoch 13/30
60000/60000 [==============================] - 53s 880us/step - loss: 150.3266 - val_loss: 151.6687
Epoch 14/30
60000/60000 [==============================] - 52s 874us/step - loss: 149.7864 - val_loss: 150.8456
Epoch 15/30
60000/60000 [==============================] - 53s 879us/step - loss: 149.2886 - val_loss: 149.8152
Epoch 16/30
60000/60000 [==============================] - 55s 924us/step - loss: 148.9666 - val_loss: 149.1268
Epoch 17/30
60000/60000 [==============================] - 56s 930us/step - loss: 148.5700 - val_loss: 151.2642
Epoch 18/30
60000/60000 [==============================] - 54s 900us/step - loss: 148.2272 - val_loss: 152.6874
Epoch 19/30
60000/60000 [==============================] - 53s 887us/step - loss: 147.8739 - val_loss: 149.0919
Epoch 20/30
60000/60000 [==============================] - 49s 821us/step - loss: 147.5699 - val_loss: 148.8338
Epoch 21/30
60000/60000 [==============================] - 51s 850us/step - loss: 147.3039 - val_loss: 149.0276
Epoch 22/30
60000/60000 [==============================] - 52s 866us/step - loss: 147.0779 - val_loss: 148.4948
Epoch 23/30
60000/60000 [==============================] - 48s 803us/step - loss: 146.8090 - val_loss: 148.0245
Epoch 24/30
60000/60000 [==============================] - 50s 838us/step - loss: 146.5394 - val_loss: 148.0376
Epoch 25/30
60000/60000 [==============================] - 53s 889us/step - loss: 146.3152 - val_loss: 147.5412
Epoch 26/30
60000/60000 [==============================] - 51s 849us/step - loss: 146.1501 - val_loss: 147.7051
Epoch 27/30
60000/60000 [==============================] - 53s 886us/step - loss: 145.9010 - val_loss: 147.0912
Epoch 28/30
60000/60000 [==============================] - 51s 844us/step - loss: 145.7388 - val_loss: 146.6814
Epoch 29/30
60000/60000 [==============================] - 50s 832us/step - loss: 145.5794 - val_loss: 146.9672
Epoch 30/30
60000/60000 [==============================] - 50s 833us/step - loss: 145.3560 - val_loss: 147.1964
Elapsed time: 1586.0493321418762
In [10]:
plot_latent(decoder)
In [11]:
z_mean, _, _ = encoder.predict(x_train2.reshape((x_train2.shape[0],pixel_size,pixel_size,1)))

plt.figure(figsize=(6, 5))
plt.scatter(z_mean[:, 0], z_mean[:, 1], s=3, c=y_train2, cmap=cmap)
plt.colorbar()
Out[11]:
<matplotlib.colorbar.Colorbar at 0x1bdc2663708>
In [ ]: